diff --git a/.clang-tidy b/.clang-tidy index 924095b4def280..868a22c2596029 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -20,7 +20,7 @@ bugprone-integer-division, bugprone-misplaced-widening-cast, -bugprone-move-forwarding-reference, -bugprone-multiple-statement-macro, --bugprone-narrowing-conversions, +bugprone-narrowing-conversions, -bugprone-not-null-terminated-result, -bugprone-parent-virtual-call, -bugprone-posix-return, @@ -155,7 +155,7 @@ cppcoreguidelines-avoid-c-arrays, -cppcoreguidelines-avoid-goto, cppcoreguidelines-c-copy-assignment-signature, cppcoreguidelines-explicit-virtual-functions, --cppcoreguidelines-init-variables, +cppcoreguidelines-init-variables, cppcoreguidelines-narrowing-conversions, cppcoreguidelines-no-malloc, -cppcoreguidelines-pro-type-const-cast, @@ -189,12 +189,12 @@ modernize-use-override, modernize-use-transparent-functors, -modernize-use-uncaught-exceptions, performance-faster-string-find, --performance-for-range-copy, +performance-for-range-copy, -performance-implicit-conversion-in-loop, -performance-inefficient-algorithm, performance-inefficient-string-concatenation, -performance-inefficient-vector-operation, --performance-move-const-arg, +performance-move-const-arg, -performance-move-constructor-init, -performance-no-automatic-move, performance-noexcept-move-constructor, diff --git a/.flake8 b/.flake8 index d9585ef248701d..91137a006d0885 100644 --- a/.flake8 +++ b/.flake8 @@ -26,6 +26,9 @@ per-file-ignores = # These files need tabs for testing. test/dygraph_to_static/test_error.py:E101,W191 + # Ignore compare with True in sot unittest + test/sot/test_dup_top.py:E712 + # temp ignore base directory python/paddle/base/*: E712, diff --git a/.gitmodules b/.gitmodules index 1fb3d67c6f27ca..8b06f4fb771cbb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -106,3 +106,7 @@ path = third_party/jitify url = https://github.com/NVIDIA/jitify.git ignore = dirty +[submodule "third_party/cccl"] + path = third_party/cccl + url = https://github.com/NVIDIA/cccl.git + ignore = dirty diff --git a/cmake/external/cccl.cmake b/cmake/external/cccl.cmake new file mode 100755 index 00000000000000..c4185bd41a2da7 --- /dev/null +++ b/cmake/external/cccl.cmake @@ -0,0 +1,31 @@ +include(ExternalProject) + +set(CCCL_PATH + "${THIRD_PARTY_PATH}/cccl" + CACHE STRING "A path setting for external_cccl path.") +set(CCCL_PREFIX_DIR ${CCCL_PATH}) +set(CCCL_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/cccl) + +# The latest commit has bugs in windows, so we set a fix commit. +set(CCCL_TAG 1f6e4bcae0fbf1bbed87f88544d8d2161c490fc1) +execute_process(COMMAND git --git-dir=${CCCL_SOURCE_DIR}/.git + --work-tree=${CCCL_SOURCE_DIR} checkout ${CCCL_TAG}) + +set(CCCL_INCLUDE_DIR ${CCCL_SOURCE_DIR}) +message("CCCL_INCLUDE_DIR is ${CCCL_INCLUDE_DIR}") +include_directories(${CCCL_INCLUDE_DIR}) + +ExternalProject_Add( + extern_cccl + ${EXTERNAL_PROJECT_LOG_ARGS} + SOURCE_DIR ${CCCL_SOURCE_DIR} + PREFIX ${CCCL_PREFIX_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "") + +add_library(cccl INTERFACE) + +add_dependencies(cccl extern_cccl) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 267f8d733cbd41..f2ef9fd845434b 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -19,12 +19,16 @@ set(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) set(CBLAS_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/openblas) set(CBLAS_TAG v0.3.7) -# OpenBLAS support Raptor Lake from v0.3.22 -if(UNIX - AND NOT APPLE - AND NOT WITH_ROCM +# Why use v0.3.18? The IDG business line encountered a random openblas error, +# which can be resolved after upgrading openblas. +# And why compile when gcc>8.2? Please refer to +# https://github.com/spack/spack/issues/19932#issuecomment-733452619 +# v0.3.18 only support gcc>=8.3 or gcc>=7.4 +if((CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 8.2 AND NOT WITH_XPU) - set(CBLAS_TAG v0.3.23) + # We only compile with openblas 0.3.18 when gcc >= 8.3 + set(CBLAS_TAG v0.3.18) endif() if(APPLE AND WITH_ARM) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 9ed3d53ccdc2b9..9f4ffd23a57e1c 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -1345,6 +1345,9 @@ function(math_library TARGET) if(WITH_GPU) if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) list(APPEND math_common_deps cub) + elseif(${CMAKE_CUDA_COMPILER_VERSION} EQUAL 12.0 + OR ${CMAKE_CUDA_COMPILER_VERSION} GREATER 12.0) + list(APPEND math_common_deps cccl) else() list(APPEND math_common_deps) endif() diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index f7a6e9a696b70c..4134b31a966ed5 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -247,6 +247,14 @@ if(NOT DEFINED WITH_MKLDNN) endif() endif() +if(WIN32) + if(MSVC) + if(MSVC_VERSION LESS 1920) + set(WITH_MKLDNN OFF) + endif() + endif() +endif() + if(WIN32 OR APPLE OR NOT WITH_GPU @@ -375,6 +383,10 @@ if(WITH_GPU) if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) include(external/cub) # download cub list(APPEND third_party_deps extern_cub) + elseif(${CMAKE_CUDA_COMPILER_VERSION} EQUAL 12.0 + OR ${CMAKE_CUDA_COMPILER_VERSION} GREATER 12.0) + include(external/cccl) + list(APPEND third_party_deps extern_cccl) endif() set(URL "https://paddlepaddledeps.bj.bcebos.com/externalErrorMsg_20210928.tar.gz" diff --git a/paddle/cinn/backends/compiler.cc b/paddle/cinn/backends/compiler.cc index 0a64b24712f489..f63869730a11f8 100644 --- a/paddle/cinn/backends/compiler.cc +++ b/paddle/cinn/backends/compiler.cc @@ -304,6 +304,8 @@ void Compiler::CompileCudaModule(const Module& module, auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name); CHECK(fn_kernel); + fn_ptr_.push_back(reinterpret_cast(fn_kernel)); + symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast(fn_kernel)); } diff --git a/paddle/cinn/backends/compiler.h b/paddle/cinn/backends/compiler.h index a468193d4d85a6..f269b00492a420 100644 --- a/paddle/cinn/backends/compiler.h +++ b/paddle/cinn/backends/compiler.h @@ -121,6 +121,8 @@ class Compiler final { */ void* Lookup(absl::string_view fn_name); + std::vector GetFnPtr() const { return fn_ptr_; } + private: void CompileCudaModule(const ir::Module& module, const std::string& code = ""); @@ -136,6 +138,7 @@ class Compiler final { Target target_; std::unique_ptr engine_; + std::vector fn_ptr_; #ifdef CINN_WITH_CUDA std::unique_ptr cuda_module_; #endif diff --git a/paddle/cinn/hlir/dialect/operator/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/CMakeLists.txt index dd1b708ce9fe44..570058329d0d39 100644 --- a/paddle/cinn/hlir/dialect/operator/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(ir) +add_subdirectory(transforms) diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index 1a5857fd2cfe20..e831bc7114f95c 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -35,6 +35,7 @@ if(NOT CINN_ONLY) COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} --op_yaml_path ${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file} + DEPENDS ${cinn_op_gen_parsed_yaml_file} ${cinn_op_yaml_file} VERBATIM) add_custom_command( diff --git a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h index 99e12a3d13ab45..724aed031165d6 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h +++ b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h @@ -77,5 +77,27 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage { ParamKey data_; }; +struct JITInfoAttributeStorage : public pir::AttributeStorage { + using ParamKey = cinn::hlir::framework::newir::CUDAJITInfo; + explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {} + + static JITInfoAttributeStorage* Construct(const ParamKey& key) { + return new JITInfoAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey& key) { + return std::hash()(*(reinterpret_cast(key.fn_ptr))); + } + + bool operator==(const ParamKey& key) const { + return data_.fn_ptr == key.fn_ptr; + } + + const ParamKey& GetAsKey() const { return data_; } + + private: + ParamKey data_; +}; + } // namespace dialect } // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index db81b53a16f968..3a4ebb63679f39 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -44,7 +44,7 @@ std::vector GroupOp::ops() { inner_block->end()); } -void GroupOp::Verify() {} +void GroupOp::VerifySig() {} void GroupOp::Print(pir::IrPrinter &printer) { auto &os = printer.os; diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 9d469d9f776c4b..39d433790be78f 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -36,7 +36,7 @@ class GroupOp : public pir::Op { pir::Block *block(); std::vector ops(); - void Verify(); + void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT }; diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc index 554d7357af970e..43d7a79f03de48 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc @@ -19,7 +19,13 @@ namespace dialect { const GroupInfo &GroupInfoAttribute::data() const { return storage()->GetAsKey(); } + +const cinn::hlir::framework::newir::CUDAJITInfo &CUDAJITInfoAttribute::data() + const { + return storage()->GetAsKey(); +} } // namespace dialect } // namespace cinn IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute) diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h index 6e92b45002785a..21724e7e3f6c9b 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h +++ b/paddle/cinn/hlir/dialect/operator/ir/op_attribute.h @@ -33,7 +33,22 @@ class GroupInfoAttribute : public pir::Attribute { const GroupInfo& data() const; }; +class CUDAJITInfoAttribute : public pir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CUDAJITInfoAttribute, + JITInfoAttributeStorage); + + bool operator<(const CUDAJITInfoAttribute& right) const { + return storage() < right.storage(); + } + + const cinn::hlir::framework::newir::CUDAJITInfo& data() const; +}; + } // namespace dialect } // namespace cinn IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute) +IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute) diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc index 6d2f0409f24e96..11ccd77bb109d0 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc @@ -39,20 +39,31 @@ void OperatorDialect::initialize() { >(); RegisterOp(); RegisterAttribute(); + RegisterAttribute(); } void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {} void OperatorDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { - os << "(" << attr.dialect().name(); - os << '.'; - if (auto group_info_attr = attr.dyn_cast()) { - const GroupInfo &data = group_info_attr.data(); - os << "GroupInfo)" - << "[" << data.fn_name << "]"; + if (attr.isa()) { + os << "(" << attr.dialect().name(); + os << '.'; + if (auto group_info_attr = attr.dyn_cast()) { + const GroupInfo &data = group_info_attr.data(); + os << "GroupInfo)" + << "[" << data.fn_name << "]"; + } + { os << "<#AttrNotImplemented>"; } + } else if (attr.isa()) { + auto cuda_jit_info = attr.dyn_cast(); + + os << "(" << cuda_jit_info.data().fn_ptr; + os << ')'; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "cinn dialect only support GrupInfo and CUDAJITInfo")); } - { os << "<#AttrNotImplemented>"; } } void OperatorDialect::PrintOperation(pir::Operation *op, diff --git a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml index 096d2c4e652b17..9f14c6e4066611 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/ops.yaml +++ b/paddle/cinn/hlir/dialect/operator/ir/ops.yaml @@ -1,8 +1,25 @@ -- op : add - args : (Tensor x, Tensor y) +- op : broadcast + args : (Tensor x, int64_t[] broadcast_axes, int64_t[] out_shape) output : Tensor(out) infer_meta : - func : ElementwiseInferMeta + func : CINNBroadcastInferMeta + param : [x, broadcast_axes, out_shape] kernel : - func : add - inplace : (x -> out) + func : expand + param : [x, broadcast_axes] + +- op : reduce_max + args : (Tensor x, int64_t[] axis, bool keep_dim) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : frobenius_norm + +- op : reduce_sum + args : (Tensor x, int64_t[] axis, bool keep_dim) + output : Tensor(out) + infer_meta : + func : ReduceInferMeta + kernel : + func : frobenius_norm diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt new file mode 100644 index 00000000000000..770e78d191e3dc --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -0,0 +1,10 @@ +if(NOT CINN_ONLY) + cinn_cc_library( + op_with_group_merge_pass + SRCS + group_with_group_merge_pass.cc + op_with_group_merge_pass.cc + tensor_node.cc + DEPS + pd_op_dialect) +endif() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc new file mode 100644 index 00000000000000..e9c165bbcec523 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass.cc @@ -0,0 +1,2126 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h" +#include "paddle/pir/core/value.h" + +#include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" +#include "paddle/phi/core/flags.h" + +PD_DECLARE_bool(enhance_vertical_fusion_with_recompute); + +namespace cinn { +namespace dialect { +namespace ir { + +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + +using Comparator = ir::Group::SharedGroupComparator; +using Hasher = ir::Group::SharedGroupHasher; + +using OpGroupPtr = ir::OpGroup; +using OpGroupList = std::vector; + +using ConditionFunction = std::function; + +class FuseHelper { + public: + virtual ~FuseHelper() = default; + + virtual bool AllOutputsSameSize(const OpGroupPtr& first, + const OpGroupPtr& second) const = 0; + + virtual bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ElementwiseFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool HorizontalWithInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool BroadcastFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseElementwise(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool ReduceFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool IsReachable(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const = 0; + + virtual bool DetectCycleIfFuse(const OpGroupPtr& src, + const OpGroupPtr& dst) const = 0; + + virtual bool IsConsumerSetsReachable( + const OpGroupPtr& group, + const std::unordered_set& consumers) const = 0; + + protected: + FuseHelper() = default; +}; + +template +class GraphGroupFuseHelper final : public FuseHelper { + public: + explicit GraphGroupFuseHelper(const FusePassCtxT* ctx) : ctx_(ctx) {} + + bool AllOutputsSameSize(const OpGroupPtr& first, + const OpGroupPtr& second) const override; + + bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ElementwiseFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool HorizontalWithInjective(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool BroadcastFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool InjectiveHorizontalWithReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ReduceFuseElementwise(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ReduceFuseBroadcast(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool ReduceFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) const override; + + bool IsReachable(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const override { + return IsReachableInDag(lhs, rhs) || IsReachableInDag(rhs, lhs); + } + + bool DetectCycleIfFuse(const OpGroupPtr& lhs, + const OpGroupPtr& rhs) const override { + return ReachableIfDirectEdgeIgnored(lhs, rhs) || + ReachableIfDirectEdgeIgnored(rhs, lhs); + } + + bool IsConsumerSetsReachable( + const OpGroupPtr& group, + const std::unordered_set& consumers) const override { + for (const auto& consumer : consumers) { + if (group == consumer) { + continue; + } + if (IsReachableInDag(consumer, group)) { + return true; + } + } + return false; + } + + private: + bool IsReachableInDag(const OpGroupPtr& producer, + const OpGroupPtr& consumer) const { + // const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + // return node.GetGroup()->min_depth; + // }; + // const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + // return node.GetGroup()->max_depth; + // }; + // const auto& VisitNextNodes = + // [&](const OpGroupPtr& node, + // const std::function& Visit) { + // for (const auto& node_producer : node.producers()) { + // Visit(node_producer); + // } + // }; + // common::IsReachablePredicator is_reachable( + // MinDepth4Node, MaxDepth4Node, VisitNextNodes); + // return is_reachable(consumer, producer, [](OpGroupPtr) {}); + // TODO(phlrain) : support IsReachable + return false; + } + + bool ReachableIfDirectEdgeIgnored(const OpGroupPtr& producer, + const OpGroupPtr& consumer) const { + // const auto& MinDepth4Node = [&](const OpGroupPtr& node) { + // return node.GetGroup()->min_depth; + // }; + // const auto& MaxDepth4Node = [&](const OpGroupPtr& node) { + // return node.GetGroup()->max_depth; + // }; + // const auto& VisitNextNodes = + // [&](const OpGroupPtr& node, + // const std::function& Visit) { + // for (const auto& node_producer : node.producers()) { + // if (node == consumer && node_producer == producer) { + // continue; + // } + // Visit(node_producer); + // } + // }; + // common::IsReachablePredicator is_reachable( + // MinDepth4Node, MaxDepth4Node, VisitNextNodes); + // return is_reachable(consumer, producer, [](OpGroupPtr) {}); + // TODO(phlrain) : support IsReachable + return false; + } + + const FusePassCtxT* ctx_; +}; + +class FusePassCtx { + public: + virtual ~FusePassCtx() {} + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + protected: + FusePassCtx() = default; +}; + +class LightwareFusePassCtx : public FusePassCtx { + public: + virtual ~LightwareFusePassCtx() {} + + virtual const OpGroupPtr& PickOpGroup() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + virtual void MarkFusible(const OpGroupList& candidates) = 0; + + protected: + LightwareFusePassCtx() = default; +}; + +class GraphGroupLightwareFusePassCtx final : public LightwareFusePassCtx { + public: + GraphGroupLightwareFusePassCtx( + const OpGroupPtr& group, + const std::function& MarkFusible) + : group_(group), + MarkFusible_(MarkFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + GraphGroupLightwareFusePassCtx( + const OpGroupPtr& group, + const std::function& + MarkGroupListFusible) + : group_(group), + MarkGroupListFusible_(MarkGroupListFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + const OpGroupPtr& PickOpGroup() const override { return group_; } + + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } + + void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override { + MarkFusible_(first, second); + } + + void MarkFusible(const OpGroupList& candidates) override { + MarkGroupListFusible_(candidates); + } + + private: + const OpGroupPtr& group_; + const std::function + MarkFusible_; + const std::function + MarkGroupListFusible_; + const std::unique_ptr fuse_helper_; +}; + +class InputFusePassCtx : public FusePassCtx { + public: + virtual ~InputFusePassCtx() {} + + virtual const OpGroupList& PickConsumersWithSameInputs() const = 0; + + virtual const FuseHelper& fuse_helper() const = 0; + + virtual void MarkFusible(const OpGroupPtr& first, + const OpGroupPtr& second) = 0; + + virtual void MarkFusible(const OpGroupList& candidates) = 0; + + protected: + InputFusePassCtx() = default; +}; + +class GraphGroupInputFusePassCtx final : public InputFusePassCtx { + public: + GraphGroupInputFusePassCtx( + const OpGroupList& groups, + const std::function& MarkFusible) + : groups_(groups), + MarkFusible_(MarkFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + GraphGroupInputFusePassCtx( + const OpGroupList& groups, + const std::function& + MarkGroupListFusible) + : groups_(groups), + MarkGroupListFusible_(MarkGroupListFusible), + fuse_helper_( + new GraphGroupFuseHelper(this)) {} + + const OpGroupList& PickConsumersWithSameInputs() const override { + return groups_; + } + + const FuseHelper& fuse_helper() const override { return *fuse_helper_; } + + void MarkFusible(const OpGroupPtr& first, const OpGroupPtr& second) override { + MarkFusible_(first, second); + } + + void MarkFusible(const OpGroupList& candidates) override { + MarkGroupListFusible_(candidates); + } + + private: + const OpGroupList& groups_; + const std::function + MarkFusible_; + const std::function + MarkGroupListFusible_; + const std::unique_ptr fuse_helper_; +}; + +template +bool GraphGroupFuseHelper::AllOutputsSameSize( + const OpGroupPtr& first, const OpGroupPtr& second) const { + return is_same_size(first.GetGroup(), second.GetGroup()); +} + +template +bool GraphGroupFuseHelper::HorizontalElementwiseFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return honrizontal_elementwise_fuse_reduce(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ElementwiseFuseBroadcast( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return elementwise_fuse_broadcast(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::HorizontalWithInjective( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return horizontal_with_injective(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ElementwiseFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return elementwise_fuse_reduce(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::BroadcastFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return broadcast_fuse_reduce(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::InjectiveHorizontalWithReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return injective_horizontal_with_reduce(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseElementwise( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_elementwise(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseBroadcast( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_broadcast(src.GetGroup(), dst.GetGroup()); +} + +template +bool GraphGroupFuseHelper::ReduceFuseReduce( + const OpGroupPtr& src, const OpGroupPtr& dst) const { + return reduce_fuse_reduce(src.GetGroup(), dst.GetGroup()); +} + +template +struct HorizontalFuseUtil { + using KindKeyT = std::pair; + + static bool DetectFusabilityByKind(FusePassCtxT* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + auto out = iter->second(src, dst); + return out; + } + + typedef bool (*ConditionT)(const OpGroupPtr& src, const OpGroupPtr& dst); + + static const std::map& GetConditionMap() { + thread_local static std::map map(RawConditionMap()); + return map; + } + + static std::map RawConditionMap() { + return std::map{ + {{kElementWise, kElementWise}, &IsSameSize}, + {{kElementWise, kBroadcast}, &IsSameSize}, + {{kElementWise, kInjective}, &IsSameSize}, + {{kElementWise, kReduction}, &HorizontalElementwiseFuseReduce}, + + {{kBroadcast, kElementWise}, &IsSameSize}, + {{kBroadcast, kBroadcast}, &IsSameSize}, + {{kBroadcast, kInjective}, &IsSameSize}, + {{kBroadcast, kReduction}, &IsSameSize}, + + {{kInjective, kElementWise}, &IsSameSize}, + {{kInjective, kBroadcast}, &IsSameSize}, + {{kInjective, kInjective}, &IsSameSize}, + {{kInjective, kReduction}, &IsSameSize}, + + {{kReduction, kElementWise}, &HorizontalElementwiseFuseReduce}, + {{kReduction, kBroadcast}, &IsSameSize}, + {{kReduction, kInjective}, &IsSameSize}, + {{kReduction, kReduction}, &ReduceFuseReduce}, + }; + } + + static bool IsSameSize(const OpGroupPtr& src, const OpGroupPtr& dst) { + return cinn::dialect::ir::IsSameSize(src, dst); + } + + static bool HorizontalElementwiseFuseReduce(const OpGroupPtr& src, + const OpGroupPtr& dst) { + // if same shape with horizontal relation + if (IsSameSize(src, dst)) { + return true; + } + + const OpGroupPtr* ele_group = nullptr; + const OpGroupPtr* reduce_group = nullptr; + + if (src.kind() == kReduction) { + ele_group = &dst; + reduce_group = &src; + } else { + ele_group = &src; + reduce_group = &dst; + } + + size_t size_ele = + phi::product(GetMasterNode(*ele_group).outputs()[0].shape()); + + bool can_fuse = false; + reduce_group->WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + size_t size_master = phi::product(op.outputs()[0].shape()); + if (size_ele == size_master) { + can_fuse = true; + } + } + }); + + return can_fuse; + } + + static bool ReduceFuseReduce(const OpGroupPtr& src, const OpGroupPtr& dst) { + // return ctx->fuse_helper().ReduceFuseReduce(src, dst); + return reduce_fuse_reduce(src.GetGroup(), dst.GetGroup()); + } +}; + +class FusePass { + public: + virtual ~FusePass() = default; + + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + + protected: + FusePass() = default; +}; + +class InputFusePass : public FusePass { + public: + virtual ~InputFusePass() = default; + + virtual void operator()(InputFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "InputFuse"; } + + virtual int Benefit() const = 0; + + protected: + InputFusePass() = default; +}; + +class DefaultInputFusePass final : public InputFusePass { + public: + DefaultInputFusePass() : InputFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(InputFusePassCtx* ctx) const override { + const auto& consumer_set = ctx->PickConsumersWithSameInputs(); + + const std::unordered_set consumer_candidates = + [&]() -> std::unordered_set { + std::unordered_set consumers; + for (const auto& consumer : consumer_set) { + if (consumer.kind() == kElementWise || consumer.kind() == kBroadcast || + consumer.kind() == kInjective || consumer.kind() == kReduction) { + consumers.insert(consumer); + } + } + return consumers; + }(); + if (consumer_candidates.size() <= 1) { + return; + } + + std::vector fusionable_consumers; + for (auto& candidate : consumer_candidates) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, + consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind( + ctx, candidate, last)) { + continue; + } + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. + if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups : fusionable_consumers) { + if (groups.size() > 1) { + ctx->MarkFusible(groups); + } + } + VLOG(1) << "DefaultInputFusePass Finish"; + } +}; + +class LightwareFusePass : public FusePass { + public: + virtual ~LightwareFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + virtual const std::string FuseMode() const = 0; + + virtual int Benefit() const = 0; + + protected: + LightwareFusePass() = default; +}; + +class HorizontalFusePass : public LightwareFusePass { + public: + virtual ~HorizontalFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "HorizontalFuse"; } + + virtual int Benefit() const = 0; + + protected: + HorizontalFusePass() = default; +}; + +class DefaultHorizontalFusePass final : public HorizontalFusePass { + public: + DefaultHorizontalFusePass() : HorizontalFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const std::unordered_set consumer_candidates = + [&]() -> std::unordered_set { + std::unordered_set consumers; + for (const auto& consumer : producer.consumers()) { + if (consumer.kind() == kElementWise || consumer.kind() == kBroadcast || + consumer.kind() == kInjective || consumer.kind() == kReduction) { + consumers.insert(consumer); + } + } + return consumers; + }(); + if (consumer_candidates.size() <= 1) { + return; + } + + std::vector fusionable_consumers; + for (auto& candidate : consumer_candidates) { + if (ctx->fuse_helper().IsConsumerSetsReachable(candidate, + consumer_candidates)) { + continue; + } + if (fusionable_consumers.empty()) { + fusionable_consumers.push_back({candidate}); + continue; + } + // check each fusionable groups + bool fusionable = false; + for (auto& groups : fusionable_consumers) { + auto& last = groups.back(); + if (!HorizontalFuseUtil::DetectFusabilityByKind( + ctx, candidate, last)) { + continue; + } + groups.push_back(candidate); + fusionable = true; + break; + } + + // if can't fuse to othors Groups, new Groups. + if (!fusionable) { + fusionable_consumers.push_back({candidate}); + } + } + + for (const auto& groups : fusionable_consumers) { + if (groups.size() > 1) { + // Trick for BERT, maybe not required, wait for substitution from + // unordered_set to set + if (groups.size() == 2) { + OpGroupList fuse_group; + if (groups[1].group_id().substr(0, 4) == "cast" && + groups[0].group_id() == "reshape_split") { + fuse_group.push_back(groups[1]); + fuse_group.push_back(groups[0]); + ctx->MarkFusible(fuse_group); + continue; + } + } + ctx->MarkFusible(groups); + } + } + } +}; + +class VerticalFusePass : public LightwareFusePass { + public: + virtual ~VerticalFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "VerticalFuse"; } + + virtual int Benefit() const = 0; + + protected: + VerticalFusePass() = default; +}; + +class DefaultVerticalFusePass final : public VerticalFusePass { + public: + DefaultVerticalFusePass() : VerticalFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); + } + return consumers; + }(); + if (consumers.size() == 0) { + return; + } + + std::vector candidates; + for (size_t i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + break; + } + candidates.push_back(consumer); + } + if (candidates.size() == consumers.size() && + producer.kind() == kElementWise) { + return; + } + + for (size_t i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + VLOG(4) << "Can't fuse because detect cycle"; + continue; + } + ctx->MarkFusible(producer, consumer); + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } + + typedef bool (*ConditionT)(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst); + + static const std::map& GetConditionMap() { + thread_local static std::map map(RawConditionMap()); + return map; + } + + static std::map RawConditionMap() { + return std::map{ + {{OpPatternKind::kElementWise, kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kElementWise, kBroadcast}, + &DefaultVerticalFusePass::ElementwiseFuseBroadcast}, + {{OpPatternKind::kElementWise, kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kElementWise, kReduction}, + &DefaultVerticalFusePass::ElementwiseFuseReduce}, + + {{OpPatternKind::kBroadcast, kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, kBroadcast}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kBroadcast, kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kBroadcast, kReduction}, + &DefaultVerticalFusePass::BroadcastFuseReduce}, + + {{OpPatternKind::kInjective, kElementWise}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, kBroadcast}, + &DefaultVerticalFusePass::IsSameSize}, + {{OpPatternKind::kInjective, kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kInjective, kReduction}, + &DefaultVerticalFusePass::InjectiveHorizontalWithReduce}, + + {{OpPatternKind::kReduction, kElementWise}, + &DefaultVerticalFusePass::ReduceFuseElementwise}, + {{OpPatternKind::kReduction, kBroadcast}, + &DefaultVerticalFusePass::ReduceFuseBroadcast}, + {{OpPatternKind::kReduction, kInjective}, + &DefaultVerticalFusePass::HorizontalWithInjective}, + {{OpPatternKind::kReduction, kReduction}, + &DefaultVerticalFusePass::ReduceFuseReduce}, + }; + } + + static bool IsSameSize(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return cinn::dialect::ir::IsSameSize(src, dst); + } + + static bool ElementwiseFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseBroadcast(src, dst); + } + + static bool HorizontalWithInjective(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().HorizontalWithInjective(src, dst); + } + + static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ElementwiseFuseReduce(src, dst); + } + + static bool BroadcastFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().BroadcastFuseReduce(src, dst); + } + + static bool InjectiveHorizontalWithReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().InjectiveHorizontalWithReduce(src, dst); + } + + static bool ReduceFuseElementwise(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseElementwise(src, dst); + } + + static bool ReduceFuseBroadcast(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseBroadcast(src, dst); + } + + static bool ReduceFuseReduce(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) { + return ctx->fuse_helper().ReduceFuseReduce(src, dst); + } +}; + +class RecomputeFusePass : public LightwareFusePass { + public: + virtual ~RecomputeFusePass() = default; + + virtual void operator()(LightwareFusePassCtx* ctx) const = 0; + + const std::string FuseMode() const final { return "RecomputeFuse"; } + + virtual int Benefit() const = 0; + + protected: + RecomputeFusePass() = default; +}; + +class DefaultRecomputeFusePass final : public RecomputeFusePass { + public: + DefaultRecomputeFusePass() : RecomputeFusePass() {} + + int Benefit() const override { return 100; } + + void operator()(LightwareFusePassCtx* ctx) const override { + const auto& producer = ctx->PickOpGroup(); + const OpGroupList consumers = [&]() { + OpGroupList consumers; + for (const auto& consumer : producer.consumers()) { + consumers.push_back(consumer); + } + return consumers; + }(); + // Borrows unsafe_candidates and candidates concept from origin + // fusion_merge_pass + std::vector unsafe_candidates; + std::vector candidates; + for (size_t i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers.at(i); + if (!DetectFusabilityByKind(ctx, producer, consumer)) { + continue; + } + unsafe_candidates.push_back(consumer); + if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { + continue; + } + candidates.push_back(consumer); + } + + if (!candidates.empty() && unsafe_candidates.size() == consumers.size() && + producer.kind() == kElementWise) { + for (const auto& consumer : consumers) { + ctx->MarkFusible(producer, consumer); + } + } + } + + using KindKeyT = std::pair; + bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + const OpGroupPtr& src, + const OpGroupPtr& dst) const { + const KindKeyT kind_pair(src.kind(), dst.kind()); + const auto& map = DefaultVerticalFusePass::GetConditionMap(); + const auto& iter = map.find(kind_pair); + if (iter == map.end()) { + return false; + } + return iter->second(ctx, src, dst); + } +}; + +struct LightwareFusePassComparator { + bool operator()(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) const { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +struct InputFusePassComparator { + bool operator()(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) const { + return lhs->Benefit() > rhs->Benefit(); + } +}; + +class FusionPassMap { + public: + static FusionPassMap& Instance() { + static FusionPassMap global_fusion_pass_map; + return global_fusion_pass_map; + } + + bool Has(const std::string& pass_name) const { + return map_.find(pass_name) != map_.end(); + } + + void Insert(const std::string& pass_name, + const std::shared_ptr& pass) { + CHECK(!Has(pass_name)) << "FusePass " << pass_name + << " has already been registered."; + map_.insert({pass_name, pass}); + } + + std::shared_ptr Get(const std::string& pass_name) const { + auto it = map_.find(pass_name); + CHECK(it != map_.end()) + << "FusePass " << pass_name << " has not been registered."; + return it->second; + } + + // fuse_mode: HorizontalFuse, VerticalFuse, RecomputeFuse + std::vector> GetLightwareFusePassesByMode( + const std::string& fuse_mode) const { + CHECK(fuse_mode == "HorizontalFuse" || fuse_mode == "VerticalFuse" || + fuse_mode == "RecomputeFuse") + << "fuse_mode only supports HorizontalFuse, VerticalFuse and " + "RecomputeFuse. Please check your input modes = " + << fuse_mode; + std::set, LightwareFusePassComparator> + candidate_passes; + for (const auto& iter : map_) { + if (fuse_mode == iter.second->FuseMode()) { + candidate_passes.insert( + std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>( + candidate_passes.begin(), candidate_passes.end()); + } + + std::vector> GetInputFusePasses() const { + std::set, InputFusePassComparator> + candidate_passes; + for (const auto& iter : map_) { + if (iter.second->FuseMode() == "InputFuse") { + candidate_passes.insert( + std::dynamic_pointer_cast(iter.second)); + } + } + return std::vector>(candidate_passes.begin(), + candidate_passes.end()); + } + + private: + FusionPassMap() = default; + std::unordered_map> map_; + + DISABLE_COPY_AND_ASSIGN(FusionPassMap); +}; + +class Registrar { + public: + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. So, as long as the callee code calls USE_OP, the global + // registrar variable won't be removed by the linker. + void Touch() {} +}; + +template +class FusionPassRegistrar final : public Registrar { + public: + explicit FusionPassRegistrar(const std::string& pass_name) { + FusionPassMap::Instance().Insert( + pass_name, std::shared_ptr(new PassClassT())); + } +}; + +// Op Fusion Pass which performs Ops fusion, Ops are fused +// "vertically", meaning producing Ops are fused into their consumers +// with the intent that the loops which compute their values will be fused in +// code generation. +class GeneralFusionMergePassHelper { + public: + explicit GeneralFusionMergePassHelper(const ::pir::Program* graph, + const GroupList& group_list) + : graph_(graph) { + fusion_groups_ = group_list; + // init input to consumers. + InitInputToConsumers(); + // init fusion group index. + InitFusionGroupsAndIndex(); + + if (!FusionPassMap::Instance().Has("DefaultHorizontalFusePass")) { + FusionPassMap::Instance().Insert( + "DefaultHorizontalFusePass", + std::make_shared()); + } + if (!FusionPassMap::Instance().Has("DefaultVerticalFusePass")) { + FusionPassMap::Instance().Insert( + "DefaultVerticalFusePass", + std::make_shared()); + } + + if (!FusionPassMap::Instance().Has("DefaultRecomputeFusePass")) { + FusionPassMap::Instance().Insert( + "DefaultRecomputeFusePass", + std::make_shared()); + } + + if (!FusionPassMap::Instance().Has("DefaultInputFusePass")) { + FusionPassMap::Instance().Insert( + "DefaultInputFusePass", std::make_shared()); + } + } + + GroupList operator()() { + // run fusion merge untill no update. + DoFusionMerge(); + for (auto& group : fusion_groups_) { + VLOG(3) << "Fusion Group -> " << group->group_id; + for (auto& sub_group : group->fused_sub_groups) { + VLOG(3) << " Fused Sub-Group -> " << sub_group->group_id; + } + for (const auto& producer : group->producer_groups()) { + VLOG(3) << " Producer -> " << producer->group_id; + } + for (const auto& consumer : group->consumer_groups()) { + VLOG(3) << " Consumer -> " << consumer->group_id; + } + } + return fusion_groups_; + } + + private: + void DoFusionMerge() { + VLOG(3) << "DoFusionMerge...!"; + while (DoGeneralHorizontalFusion()) { + } + while (DoGeneralVerticalFusion()) { + } + while (DoGeneralRecomputeAndVerticalFusion()) { + } + } + + bool DoGeneralHorizontalFusion() { + VLOG(3) << "DoGeneralHorizontalFusion...!"; + bool updated = false; + for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + updated |= GeneralHorizontalFuse(producer); + } + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + bool DoGeneralVerticalFusion() { + VLOG(3) << "DoGeneralVerticalFusion...!"; + bool updated = false; + for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + updated |= GeneralHorizontalFuse(producer); + updated |= GeneralVerticalFuse(producer); + } + + // fuse input consumers + updated |= GeneralInputFuse(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + bool DoGeneralRecomputeAndVerticalFusion() { + VLOG(3) << "DoGeneralRecomputeAndVerticalFusion...!"; + bool updated = false; + for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { + auto producer = fusion_groups_[idx]; + VLOG(3) << "Fusion Producer idx " << idx << " Group -> " + << producer->group_id; + // if producer is sub group. + if (producer->belong_groups.size()) { + continue; + } + // do horizontal fusion. + bool recompute_success = GeneralRecomputeFuse(producer); + updated |= recompute_success; + if (!recompute_success) { + updated |= GeneralVerticalFuse(producer); + } + } + + // fuse input consumers + updated |= GeneralInputFuse(); + + if (updated) { + UpdateFusionGroup(); + } + return updated; + } + + void UpdateFusionGroup() { + VLOG(3) << "UpdateFusionGroup..."; + GroupList fusion_groups; + std::unordered_set fusion_groups_set; + // update fusion_groups_ + for (auto& group : fusion_groups_) { + if (!group->belong_groups.size()) { + fusion_groups.push_back(group); + fusion_groups_set.insert(group); + } + } + // keep group in order + fusion_groups_.clear(); + fusion_groups_index_.clear(); + while (!fusion_groups_set.empty()) { + bool is_ring = true; + for (size_t idx = 0; idx < fusion_groups.size(); ++idx) { + auto& group = fusion_groups[idx]; + if (!group.get()) { + continue; + } + + bool exist = false; + for (const auto& producer : group->producer_groups()) { + if (fusion_groups_set.count(producer)) { + VLOG(4) << group->group_id << " " << producer->group_id; + exist = true; + break; + } + } + + if (!exist) { + fusion_groups_index_[group] = fusion_groups_.size(); + fusion_groups_.push_back(group); + fusion_groups_set.erase(group); + group.reset(); + is_ring = false; + continue; + } + } + if (is_ring) { + LOG(FATAL) << "Exists Ring, Please Check!"; + } + } + } + + std::vector> RawHorizontalFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "HorizontalFuse"); + } + + const std::vector>& + GetHorizontalFusePasses() const { + thread_local static std::vector> + fuse_passes = RawHorizontalFusePasses(); + return fuse_passes; + } + + void EnableFusedHorizontalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer.consumers().size() <= 1) { + return; + } + const auto& fuse_passes = GetHorizontalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralHorizontalFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralHorizontalFuse handling producer : " + << producer->group_id; + const auto& GetFusableConsumerGroupLists = + [&]() -> std::vector { + std::vector tagged_lists; + const auto& MarkFusible = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); + }; + GraphGroupLightwareFusePassCtx fuse_ctx(ir::OpGroup(producer), + MarkFusible); + EnableFusedHorizontalGroups(&fuse_ctx); + return tagged_lists; + }; + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; + } + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group : group_list) { + tmp.push_back(group.GetGroup()); + } + ret.push_back(tmp); + } + return ret; + }; + + const auto& group_lists = GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; + } + for (const auto& group_list : group_lists) { + HorizontalFuse(group_list); + } + + return true; + } + + std::vector> RawInputFusePasses() const { + return FusionPassMap::Instance().GetInputFusePasses(); + } + + const std::vector>& GetInputFusePasses() + const { + thread_local static std::vector> + fuse_passes = RawInputFusePasses(); + return fuse_passes; + } + + void EnableFusedInputGroups(InputFusePassCtx* ctx) const { + const auto& fuse_passes = GetInputFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool CallGeneralInputFusePass( + const std::unordered_set& consumers) { + VLOG(3) << "CallGeneralInputFusePass...!"; + const auto& GetFusableConsumerGroupLists = + [&]() -> std::vector { + std::vector tagged_lists; + const auto& MarkFusible = [&](const OpGroupList& candidates) { + tagged_lists.push_back(candidates); + }; + OpGroupList consumer_groups; + consumer_groups.reserve(consumers.size()); + for (auto& consumer : consumers) { + consumer_groups.push_back(ir::OpGroup(consumer)); + } + GraphGroupInputFusePassCtx fuse_ctx(consumer_groups, MarkFusible); + EnableFusedInputGroups(&fuse_ctx); + return tagged_lists; + }; + const auto& GetFusableConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusableConsumerGroupLists(); + if (group_lists.empty()) { + return std::vector{}; + } + std::vector ret; + for (const auto& group_list : group_lists) { + GroupList tmp; + for (const auto& group : group_list) { + tmp.push_back(group.GetGroup()); + } + ret.push_back(tmp); + } + return ret; + }; + + const auto& group_lists = GetFusableConsumerGroupList(); + if (group_lists.empty()) { + return false; + } + for (const auto& group_list : group_lists) { + HorizontalFuse(group_list); + } + + return true; + } + + void HorizontalFuse(const GroupList& consumers) { + VLOG(3) << "HorizontalFuse Groups..."; + // create fusion group + auto fused_group = std::make_shared(); + // As recompute exist which may case sub-group used by more than one time. + std::vector repeat_sub_groups; + std::unordered_set sub_group_set; + // find the first consumer. + GroupPtr first_consumer(nullptr); + // fuse all group into fusion group. + for (const auto& consumer : consumers) { + VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!"; + // update depth + fused_group->max_depth = + std::max(fused_group->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(fused_group->min_depth, consumer->min_depth); + // update group id + if (fused_group->group_id.size()) { + fused_group->group_id += "_" + consumer->group_id; + } else { + fused_group->group_id = consumer->group_id; + } + // set op pattern kind + fused_group->op_pattern_kind = + static_cast(fused_group->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) + ? fused_group->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + for (auto& node : consumer->input_nodes) { + if (fused_group->input_nodes.count(node.first)) { + fused_group->input_nodes[node.first] += node.second; + } else { + fused_group->input_nodes.insert(node); + } + } + // output node + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + // internal node + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // master node + for (auto& node : consumer->master_nodes) { + if (GetOpKind(node->name()) == kReduction) { + fused_group->master_nodes.insert(node); + } + } + // insert sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + // check sub group is repeat. + if (sub_group_set.count(sub_group)) { + VLOG(3) << sub_group->group_id << " is repeated!"; + repeat_sub_groups.push_back(sub_group); + continue; + } + // record sub group + sub_group_set.insert(sub_group); + + // insert to fused sub group. + fused_group->fused_sub_groups.push_back(sub_group); + // update belongs group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + // producer group + for (auto& producer : *consumer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(producer); + // update producer's consumer + producer->mut_consumer_groups()->erase(consumer); + producer->mut_consumer_groups()->insert(fused_group); + } + // consumer group + for (auto& gconsumer : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(gconsumer); + // update consumer's producer + gconsumer->mut_producer_groups()->erase(consumer); + gconsumer->mut_producer_groups()->insert(fused_group); + } + // belongs group + consumer->belong_groups.insert(fused_group); + + // find the first consumer. + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; + if (first_consumer.get()) { + if (fusion_groups_index_[consumer] < + fusion_groups_index_[first_consumer]) { + first_consumer = consumer; + } + } else { + first_consumer = consumer; + } + } + + // if node is output nodes of sub_group, check it can't be internal node. + for (auto& sub_group : repeat_sub_groups) { + // check each output node in sub_group. + for (auto& node : sub_group->output_nodes) { + // if node is not output node of fused_group. + if (!fused_group->output_nodes.count(node)) { + fused_group->internal_nodes.insert(node); + } + } + } + + if (static_cast(kReduction) > + static_cast((consumers.back())->op_pattern_kind)) { + auto consumer = consumers.back(); + + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + } else { + for (auto consumer = consumers.rbegin(); consumer != consumers.rend(); + ++consumer) { + ::pir::Operation* master_node = nullptr; + for (auto& node : (*consumer)->master_nodes) { + if (GetOpKind(node->name()) != kReduction) { + master_node = node; + break; + } + } + if (master_node) { + // VLOG(3) << "Insert Master node : " << master_node->id() + // << " into group : " << fused_group->group_id; + fused_group->master_nodes.insert(master_node); + break; + } + } + } + + auto postion = fusion_groups_index_[first_consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; + } + + std::vector> RawVerticalFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "VerticalFuse"); + } + + const std::vector>& GetVerticalFusePasses() + const { + thread_local static std::vector> + fuse_passes = RawVerticalFusePasses(); + return fuse_passes; + } + + void TagVerticalGroups(LightwareFusePassCtx* ctx) const { + const auto& producer = ctx->PickOpGroup(); + if (producer.consumers().size() == 0) { + return; + } + const auto& fuse_passes = GetVerticalFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralVerticalFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralVerticalFuse...!"; + using GroupSets = std::vector>; + const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& MarkFusible = [&](const OpGroupPtr& first, + const OpGroupPtr& second) { + tagged_sets.push_back(std::make_pair(first, second)); + }; + GraphGroupLightwareFusePassCtx fuse_ctx(ir::OpGroup(producer), + MarkFusible); + TagVerticalGroups(&fuse_ctx); + return tagged_sets; + }; + + auto GetFusableConsumerGroupSet = + [&]() -> std::unordered_set { + const auto& group_sets = GetFusableConsumerOpGroupSets(); + if (group_sets.empty()) { + return {}; + } + std::unordered_set ret; + for (const auto& group_pair : group_sets) { + ret.insert(group_pair.second.GetGroup()); + } + return ret; + }; + + bool update = false; + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size()) { + SelectConsumerToFuse(producer, &consumer_groups); + } + if (consumer_groups.size() > 0) { + VerticalFuse(producer, consumer_groups); + update = true; + } + return update; + } + + void VerticalFuse(const GroupPtr& producer, + const std::unordered_set& + fusionable_consumers) { + VLOG(3) << "VerticalFuse...!"; + GroupList fused_groups; + GroupPtr master_fuesd_group(nullptr); + for (auto& consumer : fusionable_consumers) { + auto fused_group = std::make_shared(); + // update depth using consumer depth. + fused_group->max_depth = + std::max(producer->max_depth, consumer->max_depth); + fused_group->min_depth = + std::min(producer->min_depth, consumer->min_depth); + // update group id + fused_group->group_id = producer->group_id + "_" + consumer->group_id; + VLOG(3) << "fuse producer " << producer->group_id << " into consumer " + << consumer->group_id; + // fuse producer into fusion group + fused_group->op_pattern_kind = + static_cast(producer->op_pattern_kind) >= + static_cast(consumer->op_pattern_kind) + ? producer->op_pattern_kind + : consumer->op_pattern_kind; + // input nodes + fused_group->input_nodes = producer->input_nodes; + + // internal nodes + if (producer->fused_sub_groups.size()) { + for (auto& node : producer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + // convert producer's output node to internal. + for (auto node : producer->output_nodes) { + // if node is used more than 1 time. + if (consumer->input_nodes.count(node)) { + if (consumer->input_nodes[node] > 1 && node->num_operands() > 0) { + fused_group->internal_nodes.insert(node); + } + } + } + // master nodes + for (auto& node : producer->master_nodes) { + if (GetOpKind(node->name()) == kReduction) { + fused_group->master_nodes.insert(node); + } + } + + // producer groups + for (auto& group : *producer->mut_producer_groups()) { + fused_group->mut_producer_groups()->insert(group); + // update producer's producer's consumer + group->mut_consumer_groups()->erase(producer); + group->mut_consumer_groups()->insert(fused_group); + } + + // sub groups + if (producer->fused_sub_groups.size()) { + for (auto& group : producer->fused_sub_groups) { + fused_group->fused_sub_groups.push_back(group); + // update belong group + group->belong_groups.erase(producer); + group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(producer); + } + producer->belong_groups.insert(fused_group); + + // input nodes + for (auto& input_node : consumer->input_nodes) { + // if input node not in producer output. + if (!producer->output_nodes.count(input_node.first)) { + if (fused_group->input_nodes.count(input_node.first)) { + fused_group->input_nodes[input_node.first] += input_node.second; + } else { + fused_group->input_nodes.insert(input_node); + } + } + } + + // output nodes + for (auto& node : consumer->output_nodes) { + fused_group->output_nodes.insert(node); + } + + // internal nodes + if (consumer->fused_sub_groups.size()) { + for (auto& node : consumer->internal_nodes) { + fused_group->internal_nodes.insert(node); + } + } + + // master nodes + for (auto& node : consumer->master_nodes) { + fused_group->master_nodes.insert(node); + } + + // producer nodes + for (auto& group : *consumer->mut_producer_groups()) { + if (group.get() != producer.get()) { + fused_group->mut_producer_groups()->insert(group); + // update consumer's producer's consumer + group->mut_consumer_groups()->erase(consumer); + group->mut_consumer_groups()->insert(fused_group); + } + } + + // consumer nodes + for (auto& group : *consumer->mut_consumer_groups()) { + fused_group->mut_consumer_groups()->insert(group); + // update consumer's consumer's producer + group->mut_producer_groups()->erase(consumer); + group->mut_producer_groups()->insert(fused_group); + } + + // sub group + if (consumer->fused_sub_groups.size()) { + for (auto& sub_group : consumer->fused_sub_groups) { + if (std::find(fused_group->fused_sub_groups.begin(), + fused_group->fused_sub_groups.end(), + sub_group) == fused_group->fused_sub_groups.end()) { + fused_group->fused_sub_groups.push_back(sub_group); + } + // update belong group + sub_group->belong_groups.erase(consumer); + sub_group->belong_groups.insert(fused_group); + } + } else { + fused_group->fused_sub_groups.push_back(consumer); + } + consumer->belong_groups.insert(fused_group); + + fused_groups.push_back(fused_group); + CHECK(fusion_groups_index_.count(consumer)) + << "Can't find consumer " << consumer->group_id + << " index in fusion_groups_index_!"; + auto postion = fusion_groups_index_[consumer]; + fusion_groups_[postion] = fused_group; + fusion_groups_index_[fused_group] = postion; + + if (!master_fuesd_group.get()) { + master_fuesd_group = fused_group; + } + CHECK(fused_group->output_nodes.size()) + << "No output node is found, " << fused_group->group_id; + } + + for (auto& node : producer->output_nodes) { + bool be_output = true; + for (const auto& consumer : producer->consumer_groups()) { + // if consumer is in fusionable. + if (fusionable_consumers.count(consumer)) { + if (consumer->input_nodes.count(node)) { + be_output = false; + } + continue; + } + // if consumer is not in fusionable. + if (consumer->input_nodes.count(node)) { + be_output = true; + break; + } + // others node is as graph output. + } + + if (output_nodes_set_.count(node)) { + be_output = true; + } + + if (be_output) { + // VLOG(4) << "Insert Id " << node->id() << " Into Group " + // << master_fuesd_group->group_id; + master_fuesd_group->output_nodes.insert(node); + } + } + // insert unfusionable consumer groups + for (auto& consumer : *producer->mut_consumer_groups()) { + if (fusionable_consumers.count(consumer)) { + continue; + } + master_fuesd_group->mut_consumer_groups()->insert(consumer); + // update consumer's producer + consumer->mut_producer_groups()->erase(producer); + consumer->mut_producer_groups()->insert(master_fuesd_group); + } + } + + std::vector> RawRecomputeFusePasses() + const { + return FusionPassMap::Instance().GetLightwareFusePassesByMode( + "RecomputeFuse"); + } + + const std::vector>& + GetRecomputeFusePasses() const { + thread_local static std::vector> + fuse_passes = RawRecomputeFusePasses(); + return fuse_passes; + } + + void TagRecomputeGroups(LightwareFusePassCtx* ctx) const { + const auto& fuse_passes = GetRecomputeFusePasses(); + for (const auto& fuse_pass : fuse_passes) { + (*fuse_pass)(ctx); + } + } + + bool GeneralRecomputeFuse(const GroupPtr& producer) { + VLOG(3) << "GeneralRecomputeFuse handling producer : " + << producer->group_id; + using GroupSets = std::set>; + const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + GroupSets tagged_sets; + const auto& MarkFusible = [&](const OpGroupPtr& first, + const OpGroupPtr& second) { + tagged_sets.insert(std::make_pair(first, second)); + }; + GraphGroupLightwareFusePassCtx fuse_ctx(ir::OpGroup(producer), + MarkFusible); + TagRecomputeGroups(&fuse_ctx); + return tagged_sets; + }; + + auto GetFusableConsumerGroupSet = + [&]() -> std::unordered_set { + const auto& group_sets = GetFusableConsumerOpGroupSets(); + if (group_sets.empty()) { + return {}; + } + std::unordered_set ret; + for (const auto& group_pair : group_sets) { + ret.insert(group_pair.second.GetGroup()); + } + return ret; + }; + + bool update = false; + auto consumer_groups = GetFusableConsumerGroupSet(); + if (consumer_groups.size() > 0) { + CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) + << "Recompute requires fuse all consumers!"; + RecomputeFuse(producer, consumer_groups); + update = true; + } + return update; + } + + void RecomputeFuse(const GroupPtr& producer, + const std::unordered_set& + fusionable_consumers) { + VerticalFuse(producer, fusionable_consumers); + } + + void SelectConsumerToFuse( + const GroupPtr& producer, + std::unordered_set* fusionable_consumers) { + // if is const op + + // TODO(phlrain) : support constant + // if (is_const_group(this, producer)) { + if (false) { + std::unordered_set candidates; + for (auto& consumer : *fusionable_consumers) { + // if can be output node. + if (is_same_shape(producer, consumer)) { + candidates.insert(consumer); + } else { + VLOG(4) << "Fuse Producer : " << producer->group_id + << " into Consumer : " << consumer->group_id; + consumer->group_id = producer->group_id + "_" + consumer->group_id; + // just merge the node into group. + auto& sub_group = consumer->fused_sub_groups.front(); + sub_group->group_id = producer->group_id + "_" + sub_group->group_id; + sub_group->nodes.insert(sub_group->nodes.begin(), + producer->CollectNodes()[0]); + sub_group->nodes_set.insert(producer->CollectNodes()[0]); + // remove depency. + consumer->input_nodes.erase(producer->CollectNodes()[0]); + consumer->mut_producer_groups()->erase(producer); + producer->mut_consumer_groups()->erase(consumer); + } + } + + CHECK_GE(producer->consumer_groups().size(), candidates.size()); + if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && + output_nodes_set_.count(producer->CollectNodes()[0]) == 0) { + producer->belong_groups.insert(*fusionable_consumers->begin()); + } + + *fusionable_consumers = candidates; + return; + } + // 1 to 1 fusion. + if (producer->consumer_groups().size() == 1) { + return; + } + + // TODO(phlrain): support flags + // if (FLAGS_enhance_vertical_fusion_with_recompute) { + if (false) { + std::vector candidates; + for (auto& consumer : *fusionable_consumers) { + if (consumer->op_pattern_kind == kElementWise) { + candidates.push_back(consumer); + continue; + } + + auto producer_output_shape = phi::vectorize( + GetValueShape((*producer->output_nodes.begin())->result(0))); + + auto consumer_output_shape = phi::vectorize( + GetValueShape((*consumer->output_nodes.begin())->result(0))); + + auto consumer_master_input_shape = phi::vectorize(GetValueShape( + (*(consumer->master_nodes.begin()))->operand_source(0))); + + int producer_output_numel = + std::accumulate(producer_output_shape.begin(), + producer_output_shape.end(), + 1, + std::multiplies()); + int consumer_output_numel = + std::accumulate(consumer_output_shape.begin(), + consumer_output_shape.end(), + 1, + std::multiplies()); + int consumer_master_input_numel = + std::accumulate(consumer_master_input_shape.begin(), + consumer_master_input_shape.end(), + 1, + std::multiplies()); + if (producer_output_numel == consumer_output_numel) { + candidates.push_back(consumer); + continue; + } + + if (producer->op_pattern_kind != kInjective && + consumer->op_pattern_kind == kReduction && + producer_output_numel == consumer_master_input_numel) { + candidates.push_back(consumer); + } + } + sort(candidates.begin(), + candidates.end(), + [](const auto& lhs, const auto& rhs) { + return lhs->op_pattern_kind < rhs->op_pattern_kind; + }); + + fusionable_consumers->clear(); + if (candidates.size()) { + fusionable_consumers->insert(*candidates.begin()); + } + } else { + std::vector candidates; + for (auto& consumer : *fusionable_consumers) { + if (consumer->op_pattern_kind == kElementWise) { + candidates.push_back(consumer); + continue; + } + + auto shape0 = phi::vectorize( + GetValueShape((*producer->output_nodes.begin())->result(0))); + auto shape1 = phi::vectorize( + GetValueShape((*consumer->output_nodes.begin())->result(0))); + + if (std::accumulate( + shape0.begin(), shape0.end(), 1, std::multiplies()) == + std::accumulate( + shape1.begin(), shape1.end(), 1, std::multiplies())) { + candidates.push_back(consumer); + } + } + + fusionable_consumers->clear(); + if (candidates.size()) { + fusionable_consumers->insert(candidates.front()); + } + } + } + + bool IsDependency( + const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (const auto& producer_and_list : candidate->producer_groups()) { + if (producer_and_list.get() == producer_g.get()) { + continue; + } + const auto& producer = + std::dynamic_pointer_cast(producer_and_list); + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + + bool IsDependencySimplify( + const GroupPtr& producer_g, + const GroupPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + // check upper. + int check_upper_depth = producer_g.get() ? producer_g->max_depth : INT_MAX; + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (auto& producer_and_list : candidate->producer_groups()) { + if (producer_and_list.get() == producer_g.get()) { + continue; + } + const auto& producer = + std::dynamic_pointer_cast(producer_and_list); + if (producer->min_depth > check_upper_depth) { + continue; + } + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; + } + + bool GeneralInputFuse() { + VLOG(3) << "GeneralInputFuse...!"; + auto updated = false; + UpdateInputToConsumers(); + for (auto& input_consumers : input_to_consumers_) { + // if group set size == 1. + if (input_consumers.second.size() == 1) { + continue; + } + // do input fusion. + auto st = CallGeneralInputFusePass(input_consumers.second); + if (st) { + // fused consumers, update + UpdateInputToConsumers(); + } + updated |= st; + } + + return updated; + } + + void UpdateInputToConsumers() { + for (auto& input_consumers : input_to_consumers_) { + auto& consumers = input_consumers.second; + std::unordered_set updated_consumers; + for (auto& consumer : consumers) { + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(belong_group); + } else { + fused_groups.push(belong_group); + } + } + } + } + } + consumers = updated_consumers; + } + } + + void InitInputToConsumers() { + VLOG(3) << "InitInputToConsumers...!"; + // init input data node -> fusion group map. + for (auto& group : fusion_groups_) { + for (auto& node : group->nodes_set) { + // collect producer node data. + for (size_t i = 0; i < node->num_operands(); ++i) { + auto in = node->operand_source(i); + if (in) { + input_to_consumers_[in].insert(group); + } + } + } + } + } + + void InitFusionGroupsAndIndex() { + VLOG(3) << "InitFusionGroupsAndIndex...!"; + // init the postion of groups in fusion groups. + for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { + auto group = fusion_groups_[idx]; + auto belong_group = std::make_shared(); + // copy from group. + belong_group->max_depth = group->depth; + belong_group->min_depth = group->depth; + belong_group->group_id = group->group_id; + belong_group->input_nodes = group->input_nodes; + belong_group->output_nodes = group->output_nodes; + belong_group->op_pattern_kind = group->op_pattern_kind; + belong_group->master_nodes = group->master_nodes; + (*belong_group->mut_producer_groups()) = group->producer_groups(); + (*belong_group->mut_consumer_groups()) = group->consumer_groups(); + belong_group->fused_sub_groups.push_back(group); + group->belong_groups.insert(belong_group); + // replace group to fused_group + fusion_groups_[idx] = belong_group; + // record idx + fusion_groups_index_[belong_group] = idx; + } + + // update producer and consumer. + for (auto& group : fusion_groups_) { + std::unordered_set producers; + std::unordered_set consumers; + + for (const auto& producer : group->producer_groups()) { + CHECK(producer->belong_groups.size()); + producers.insert(*producer->belong_groups.begin()); + } + + for (auto& consumer : *group->mut_consumer_groups()) { + CHECK(consumer->belong_groups.size()); + consumers.insert(*consumer->belong_groups.begin()); + } + CHECK_EQ(group->producer_groups().size(), producers.size()); + CHECK_EQ(group->consumer_groups().size(), consumers.size()); + (*group->mut_producer_groups()) = producers; + (*group->mut_consumer_groups()) = consumers; + } + } + + const ::pir::Program* graph_; + GroupList fusion_groups_; + std::unordered_map fusion_groups_index_; + std::unordered_set output_nodes_set_; + std::unordered_map<::pir::Value, + std::unordered_set> + input_to_consumers_; +}; + +GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, + const GroupList& group_list) { + if (group_list.size() <= 1) { + VLOG(3) << "Don't do Fusoin Merge Pass...!"; + return group_list; + } + + GeneralFusionMergePassHelper fusion_merge_pass_helper(graph, group_list); + auto res = fusion_merge_pass_helper(); + + return res; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h new file mode 100644 index 00000000000000..19ea891531b872 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h @@ -0,0 +1,279 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" + +namespace cinn { +namespace dialect { +namespace ir { + +using OpGroupPtr = ir::OpGroup; +using OpGroupList = std::vector; + +static cinn::dialect::ir::OpNode GetMasterNode(const OpGroupPtr& op_group) { + std::vector master_nodes; + op_group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + master_nodes.push_back(op); + } + }); + if (!master_nodes.empty()) { + return master_nodes.front(); + } + + op_group.WalkOpNodes( + [&](const cinn::dialect::ir::OpNode& op) { master_nodes.push_back(op); }); + return master_nodes.back(); +} + +static bool IsSameSize(const OpGroupPtr& src, const OpGroupPtr& dst) { + cinn::dialect::ir::OpNode src_master_node = GetMasterNode(src); + cinn::dialect::ir::OpNode dst_master_node = GetMasterNode(dst); + + auto size_0 = src_master_node.outputs()[0].shape(); + auto size_1 = dst_master_node.outputs()[0].shape(); + + return phi::product(size_0) == phi::product(size_1); +} + +static std::unordered_set GetInputOps( + const OpGroupPtr& op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes([&ops_set](const cinn::dialect::ir::OpNode& op_node) { + ops_set.insert(op_node); + }); + + std::unordered_set input_ops; + op_group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + const auto& input_tensors = op.inputs(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + if (!ops_set.count(input_tensors[i].producer())) { + input_ops.insert(input_tensors[i].producer()); + } + } + }); + return input_ops; +} + +static std::unordered_set GetOutputOps( + const OpGroupPtr& op_group) { + std::unordered_set ops_set; + op_group.WalkOpNodes([&ops_set](const cinn::dialect::ir::OpNode& op_node) { + ops_set.insert(op_node); + }); + std::unordered_set output_ops; + op_group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + const auto& output_tensors = op.outputs(); + for (size_t i = 0; i < output_tensors.size(); ++i) { + auto& consumers = output_tensors[i].consumers(); + for (auto it = consumers.begin(); it != consumers.end(); ++it) { + if (!ops_set.count(*it)) { + output_ops.insert(*it); + break; + } + } + } + }); + return output_ops; +} + +// limit the group args number to less equal 512, as args stack size is 4K. +static bool limit_args(const OpGroupPtr& first, const OpGroupPtr& second) { + std::unordered_set args; + for (auto& group : {first, second}) { + for (const auto& node : GetInputOps(group)) { + args.insert(node); + } + for (const auto& node : GetOutputOps(group)) { + args.insert(node); + } + } + + if (args.size() > 512) { + return false; + } else { + return true; + } +} + +bool WithoutLastDimInReduce(const phi::DDim& inshape, + const std::vector& axes) { + // if last axis is in reduce. + if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int sum_last_axes = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + sum_last_axes *= inshape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +static int GetSharedSize(const cinn::dialect::ir::OpNode& op_node) { + const auto& inshape = op_node.inputs()[0].shape(); + // const auto& axes = op_node.GetAttr>("dim"); + // const auto& axes = op_node.Op()->attributes().at("dim").dyn_cast<> + // TODO(phlrain): get vector from attribute + std::vector axes = {1}; + if (WithoutLastDimInReduce(inshape, axes)) { + int lane = 1; + for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane = inshape[idx]; + } + // int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int max_num_threads = 1000; + if (lane > max_num_threads / 2) { + return 0; + } + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (static_cast(index + 1) < axes.size() && + axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. + int axis = lane > (max_num_threads / 2) ? axes[index] : axes[index + 1]; + if (lane <= max_num_threads) { + return lane * sizeof(float); + } else { + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; + idx > ((max_num_threads / 2) / tail); + --idx) { + if (prefix % idx == 0) { + return idx * tail * sizeof(float); + } + } + int num = max_num_threads / tail; + return num * tail * sizeof(float); + } + } + return 0; +} + +static bool ReduceFuseReduce(const OpGroupPtr& first, + const OpGroupPtr& second) { + if (!limit_args(first, second)) { + return false; + } + std::unique_ptr reducer_0 = nullptr; + first.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (!reducer_0 && op.kind() == kReduction) { + reducer_0.reset(new cinn::dialect::ir::OpNode(op)); + } + }); + CHECK(reducer_0) << "Can't find reduce op in group " << first.group_id(); + + std::unique_ptr reducer_1 = nullptr; + second.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (!reducer_1 && op.kind() == kReduction) { + reducer_1.reset(new cinn::dialect::ir::OpNode(op)); + } + }); + + CHECK(reducer_1) << "Can't find reduce op in group " << second.group_id(); + + // check reduce has same input shape and output shape + const auto& reducer_0_input_shape = reducer_0->inputs()[0].shape(); + const auto& reducer_0_output_shape = reducer_0->outputs()[0].shape(); + + const auto& reducer_1_input_shape = reducer_1->inputs()[0].shape(); + const auto& reducer_1_output_shape = reducer_1->outputs()[0].shape(); + + // TODO(phlrain): get attribute from op node + // auto reducer_0_reduce_dim = reducer_0->GetAttr>("dim"); + // auto reducer_1_reduce_dim = reducer_1->GetAttr>("dim"); + + std::vector reducer_0_reduce_dim = {0}; + std::vector reducer_1_reduce_dim = {0}; + + for (auto& dim : reducer_0_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_0_reduce_dim.size() - 1; + } + } + + for (auto& dim : reducer_1_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_1_reduce_dim.size() - 1; + } + } + + // check shape is same + if (reducer_0_input_shape == reducer_1_input_shape && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && + WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + fusion_group.WalkOpNodes([&](const cinn::dialect::ir::OpNode& op) { + if (op.kind() == OpPatternKind::kReduction) { + shared_size += GetSharedSize(op); + } + }); + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + return false; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h new file mode 100644 index 00000000000000..1b8f5b6aeacd7b --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_util.h @@ -0,0 +1,638 @@ +// Copyright (c) 20223 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" + +namespace cinn { +namespace dialect { +namespace ir { + +const std::set ConstantOps = { + "const_scalar", "fill_constant", "arange"}; + +// limit the group args number to less equal 512, as args stack size is 4K. +inline bool limit_args(const std::shared_ptr& first, + const std::shared_ptr& second) { + std::unordered_set args; + for (auto& group : {first, second}) { + for (auto node : group->input_nodes) { + args.insert(node.first); + } + for (auto node : group->output_nodes) { + args.insert(node); + } + } + + if (args.size() > 512) { + return false; + } else { + return true; + } +} + +inline bool always_fuse(const std::shared_ptr& first, + const std::shared_ptr& second) { + return true; +} + +inline bool is_same_shape(const std::shared_ptr& first, + const std::shared_ptr& second) { + if (!limit_args(first, second)) { + return false; + } + + auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0)); + auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0)); + return output_var_0 == output_var_1; +} + +inline bool is_same_size(const std::shared_ptr& first, + const std::shared_ptr& second) { + if (!limit_args(first, second)) { + return false; + } + + auto output_var_0 = GetValueShape((*first->master_nodes.begin())->result(0)); + auto output_var_1 = GetValueShape((*second->master_nodes.begin())->result(0)); + if (output_var_0 == output_var_1) { + return true; + } + + auto size_0 = phi::product(output_var_0); + auto size_1 = phi::product(output_var_1); + return size_0 == size_1; +} + +inline bool is_const_group(const std::shared_ptr& group) { + return group->CollectNodes().size() == 1 && + ConstantOps.count(group->CollectNodes()[0]->name()); +} + +inline bool elementwise_fuse_broadcast( + const std::shared_ptr& first, + const std::shared_ptr& second) { + // if producer just include const op. + if (is_const_group(first)) { + return true; + } + // if same shape with horizontal relation + if (is_same_size(first, second)) { + return true; + } + // if first's output is not all in second's input + for (auto output : first->output_nodes) { + return true; + if (!second->input_nodes.count(output)) { + return false; + } + + // TODO(phlrain): support output set here + // if (helper->output_nodes_set_.count(output)) { + // return false; + // } + + return true; + } + // 1.compute io-size + // 2.compute computation-size + // 3.compute recompute-times + // 4.compute cost + // TODO(sunli) : cost-model. + return true; +} + +inline bool honrizontal_elementwise_fuse_reduce( + const std::shared_ptr& first, + const std::shared_ptr& second) { + std::shared_ptr ele_group, reduce_group; + if (first->op_pattern_kind == kReduction) { + ele_group = second; + reduce_group = first; + } else { + ele_group = first; + reduce_group = second; + } + // if same shape with horizontal relation + if (is_same_size(first, second)) { + return true; + } + + auto ele_node_shape = + GetValueShape((*ele_group->master_nodes.begin())->result(0)); + int32_t size_ele = phi::product(ele_node_shape); + // TODO(phlrain): seems extrame danger herem, why compare multi Master Node? + for (auto* master : reduce_group->master_nodes) { + auto master_node_shape = GetValueShape(master->result(0)); + int32_t size_master = phi::product(master_node_shape); + if (size_ele == size_master) { + return true; + } + } + + return false; +} + +inline bool elementwise_fuse_reduce(const std::shared_ptr& first, + const std::shared_ptr& second) { + // if (helper->target_ == common::DefaultHostTarget()) { + // return true; + // } + // if same shape with horizontal relation + if (is_same_size(first, second)) { + return true; + } + + // if reduce nodes not in consumers of first group + std::queue<::pir::Operation*> candidates; + std::unordered_set<::pir::Operation*> first_node_set = first->NodeSet(); + std::unordered_set<::pir::Operation*> second_node_set = second->NodeSet(); + for (const auto& pair : second->input_nodes) { + if (first_node_set.find(pair.first) != first_node_set.end()) { + candidates.push(pair.first); + } + } + std::unordered_set<::pir::Operation*> visited; + std::unordered_set<::pir::Operation*> masters_in_consumers; + + while (!candidates.empty()) { + ::pir::Operation* candidate = candidates.front(); + candidates.pop(); + + // TODO(phlrain) : why only deal with first output + auto first_output = candidate->result(0); + for (auto it = first_output.use_begin(); it != first_output.use_end(); + ++it) { + auto consumer = (*it).owner(); + if (visited.count(consumer)) { + continue; + } + if (second_node_set.find(consumer) != second_node_set.end()) { + visited.insert(consumer); + candidates.push(consumer); + } + if (second->master_nodes.count(consumer)) { + masters_in_consumers.insert(consumer); + } + } + } + if (!masters_in_consumers.empty()) { + bool flag = true; + auto first_node_shape = + GetValueShape((*first->master_nodes.begin())->result(0)); + int32_t size_first = phi::product(first_node_shape); + + for (::pir::Operation* master : masters_in_consumers) { + auto second_node_shape = GetValueShape(master->result(0)); + int32_t size_second = phi::product(second_node_shape); + if (size_first != size_second) { + flag = false; + break; + } + } + if (flag) { + return true; + } + } + + // if reduce using block_reduce, can't fuse producer. + ::pir::Operation* reducer = nullptr; + for (auto& node : second->master_nodes) { + if (GetOpKind(node->name()) == kReduction) { + reducer = node; + break; + } + } + // CHECK(reducer) << "Can't find reduce op in group " << second->group_id; + + // If the elementwise's output should be fetched, the output var cannot be + // computed inline into reduce's loop, in other words, the elementwise's + // cannot fused into reduce's loop Like: group1 = {cast_0}, + // group2={broadcast_0 -> elementwise_0 -> cast_1 -> reduce_max_0} + + // TODO(phlrain) : pass output node set + // if (helper->output_nodes_set_.count(*first->master_nodes.begin())) { + // return false; + // } + + auto input_shape = GetValueShape(reducer->operand_source(0)); + std::vector reduce_axes = GetVectorAttr(reducer, "axis"); + + // int max_num_threads = helper->target_.max_num_threads(); + int max_num_threads = 1000; + // if without last dimension in reduce. + int lane = 1; + if (WithoutLastDimInReduce(input_shape, reduce_axes)) { + for (int idx = reduce_axes.back() + 1; idx < input_shape.size(); ++idx) { + lane *= input_shape[idx]; + } + if (lane > max_num_threads / 2) { + return true; + } + } + + int index = reduce_axes.size() - 1; + for (; index >= 0; --index) { + if (static_cast(index + 1) < reduce_axes.size() && + reduce_axes[index] + 1 != reduce_axes[index + 1]) { + break; + } + lane *= input_shape[reduce_axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + + if (lane <= max_num_threads) { + return true; + } else { + int prefix = input_shape[reduce_axes[index]]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; + --idx) { + if (prefix % idx == 0) { + return true; + } + } + } + return false; +} + +inline bool broadcast_fuse_reduce(const std::shared_ptr& first, + const std::shared_ptr& second) { + // if same shape with horizontal relation + if (is_same_size(first, second)) { + return true; + } + ::pir::Operation* reducer = nullptr; + for (auto& node : second->master_nodes) { + if (GetOpKind(node->name()) == kReduction) { + reducer = node; + break; + } + } + // CHECK(reducer) << "Can't find reduce op in group " << second->group_id; + + auto input_shape = GetValueShape(reducer->operand_source(0)); + auto input_size = phi::product(input_shape); + + auto output_shape = GetValueShape((*first->master_nodes.begin())->result(0)); + auto output_size = phi::product(output_shape); + + if (input_size == output_size) { + return elementwise_fuse_reduce(first, second); + } + return false; +} + +inline bool reduce_fuse_elementwise(const std::shared_ptr& first, + const std::shared_ptr& second) { + if (!is_same_size(first, second)) { + return false; + } + // if with last axis in reduce, fuse will waste computation resource. + // so use a simple model evaluate the cost. + // TODO(sunli) : cost-model. + return true; +} + +inline bool horizontal_relation(const std::shared_ptr& first, + const std::shared_ptr& second, + const OpPatternKind op_pattern_kind) { + // merge injective + auto merge_nodes_set = [](const std::shared_ptr& group) { + std::unordered_set<::pir::Operation*> nodes_set = group->nodes_set; + for (auto& sub_group : group->fused_sub_groups) { + nodes_set.insert(sub_group->nodes_set.begin(), + sub_group->nodes_set.end()); + } + return nodes_set; + }; + auto first_set = merge_nodes_set(first); + auto second_set = merge_nodes_set(second); + + auto select_node_set = [](const std::unordered_set<::pir::Operation*>& nodes, + OpPatternKind kind) { + std::unordered_set<::pir::Operation*> selected; + for (auto node : nodes) { + if (GetOpKind(node->name()) == kind) { + selected.insert(node); + } + } + return selected; + }; + auto selected_nodes = select_node_set(second_set, op_pattern_kind); + + auto check_depency = [&](::pir::Operation* node) { + std::queue<::pir::Operation*> candidates; + std::unordered_set<::pir::Operation*> visited_set; + candidates.push(node); + + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + // visit all producer node + // Get all the input Op + for (size_t i = 0; i < candidate->num_operands(); ++i) { + auto producer = + candidate->operand_source(i).dyn_cast().owner(); + // check dependency. + if (first_set.count(producer)) { + return true; + } + // check node is in region. + if (!second_set.count(producer)) { + continue; + } + // recorded visited node. + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + + return false; + }; + + for (auto node : selected_nodes) { + if (check_depency(node)) { + return false; + } + } + + return true; +} + +inline bool horizontal_with_injective( + const std::shared_ptr& first, + const std::shared_ptr& second) { + if (is_const_group(first)) { + return true; + } + + if (!is_same_size(first, second)) { + return false; + } + return horizontal_relation(first, second, kInjective); +} + +inline bool injective_horizontal_with_reduce( + const std::shared_ptr& first, + const std::shared_ptr& second) { + // check injective with injective. + if (!horizontal_relation(first, second, kInjective)) { + return false; + } + return elementwise_fuse_reduce(first, second); +} + +inline bool reduce_fuse_broadcast(const std::shared_ptr& first, + const std::shared_ptr& second) { + // if same shape with horizontal relation + if (is_same_size(first, second)) { + return true; + } + + // Traversing all reducers in all producers requires two types of conditions + // to be met. The first type is the condition that the reducer itself needs to + // meet, and the second type is the condition that the relationship between + // each reducer and its consumers with type of Broadcast needs to meet. It is + // required that each consumer of type Broadcast meet the same shape after + // broadcast as before reduce. + for (auto& node_in_master : first->master_nodes) { + if (GetOpKind(node_in_master->name()) != kReduction) { + continue; + } + ::pir::Operation* reducer = node_in_master; + // First type conditions + // Get some reduce information + auto reducer_input_shape = + phi::vectorize(GetValueShape(reducer->operand_source(0))); + auto reducer_output_shape = + phi::vectorize(GetValueShape(reducer->result(0))); + std::vector reduce_axes = GetVectorAttr(reducer, "axis"); + + auto keep_dim = false; + for (auto& axis : reduce_axes) { + if (axis == -1) { + axis = reducer_input_shape.size() - 1; + } + } + // Check if the reduce axes are continuous + int reduce_size = reducer_input_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= reducer_input_shape[idx - 1]; + } + // Check if the reduce size exceeds the hardware limit + // if (helper->target_ == common::DefaultNVGPUTarget() && + // reduce_size > helper->target_.max_num_threads()) { + // return false; + // } + + // Second type conditions + // Find directly or indirectly consumers with type of Broadcast in the + // second group + auto find_broadcasters_in_descendants = [&](::pir::Operation* producer) + -> std::unordered_set<::pir::Operation*> { + std::queue<::pir::Operation*> candidates; + std::unordered_set<::pir::Operation*> visited_set; + std::unordered_set<::pir::Operation*> broadcasters; + candidates.push(producer); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + // TODO(phlrain) : why only deal with first output + auto first_output = candidate->result(0); + for (auto it = first_output.use_begin(); it != first_output.use_end(); + ++it) { + auto consumer = (*it).owner(); + + if (!visited_set.count(consumer)) { + visited_set.insert(consumer); + candidates.push(consumer); + } + if (GetOpKind(consumer->name()) == kBroadcast && + second->NodeSet().find(consumer) != second->NodeSet().end()) { + broadcasters.insert(consumer); + } + } + } + + return broadcasters; + }; + + // Check if each broadcast node meets the conditions + std::unordered_set<::pir::Operation*> broadcasters_in_consumers = + find_broadcasters_in_descendants(reducer); + for (auto broadcaster : broadcasters_in_consumers) { + // auto = absl::get>( + // broadcaster->attrs.attr_store.at("out_shape")); + + // auto broadcast_axes = absl::get>( + // broadcaster->attrs.attr_store.at("broadcast_axes")); + // TODO(phlrain) : suport here + std::vector broadcaster_output_shape = + GetVectorAttr(broadcaster, "out_shape"); + std::vector broadcast_axes = + GetVectorAttr(broadcaster, "broadcast_axes"); + for (auto& axis : broadcast_axes) { + if (axis == -1) { + axis = broadcaster_output_shape.size() - 1; + } + } + + if (reducer_input_shape != broadcaster_output_shape) { + return false; + } + + if (keep_dim) { + continue; + } else { + // if reducer_output_shape = [1] + if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) { + continue; + } + // check union [reduce_axes, broadcast_axes] = reducer_input_shape + for (size_t idx = 0; idx < reducer_input_shape.size(); ++idx) { + if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == + broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == + reduce_axes.end()) { + return false; + } + } + } + } + } + + return true; +} + +inline bool reduce_fuse_reduce(const std::shared_ptr& first, + const std::shared_ptr& second) { + if (!limit_args(first, second)) { + return false; + } + ::pir::Operation* reducer_0 = nullptr; + for (auto& reducer : first->master_nodes) { + if (GetOpKind(reducer->name()) == kReduction) { + reducer_0 = reducer; + break; + } + } + // CHECK(reducer_0) << "Can't find reduce op in group " << first->group_id; + + ::pir::Operation* reducer_1 = nullptr; + for (auto& reducer : second->master_nodes) { + if (GetOpKind(reducer->name()) == kReduction) { + reducer_1 = reducer; + break; + } + } + CHECK(reducer_1) << "Can't find reduce op in group " << second->group_id; + // check reduce has same input shape and output shape + auto reducer_0_input_shape = GetValueShape(reducer_0->operand_source(0)); + auto reducer_0_output_shape = GetValueShape(reducer_0->result(0)); + + auto reducer_1_input_shape = GetValueShape(reducer_1->operand_source(0)); + auto reducer_1_output_shape = GetValueShape(reducer_1->result(0)); + + // auto reducer_0_reduce_dim = + // absl::get>(reducer_0->attrs.attr_store.at("dim")); + // auto reducer_1_reduce_dim = + // absl::get>(reducer_1->attrs.attr_store.at("dim")); + // TODO(phlrain) + std::vector reducer_0_reduce_dim = GetVectorAttr(reducer_0, "axis"); + std::vector reducer_1_reduce_dim = GetVectorAttr(reducer_1, "axis"); + + for (auto& dim : reducer_0_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_0_reduce_dim.size() - 1; + } + } + + for (auto& dim : reducer_1_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim == -1) { + dim = reducer_1_reduce_dim.size() - 1; + } + } + + // check shape is same + if (reducer_0_input_shape == reducer_1_input_shape && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + for (auto* master : fusion_group->master_nodes) { + if (GetOpKind(master->name()) == kReduction) { + shared_size += GetSharedSize(master); + } + } + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + if (WithoutLastDimInReduce(reducer_0_input_shape, reducer_0_reduce_dim) && + WithoutLastDimInReduce(reducer_1_input_shape, reducer_1_reduce_dim) && + reducer_0_output_shape == reducer_1_output_shape && + reducer_0_reduce_dim == reducer_1_reduce_dim) { + auto shared_size = 0; + for (auto& fusion_group : {first, second}) { + for (auto* master : fusion_group->master_nodes) { + if (GetOpKind(master->name()) == kReduction) { + shared_size += GetSharedSize(master); + } + } + } + +#define MAX_AVAILABLE_SHREAD 32 * 1024 + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } +#undef MAX_AVAILABLE_SHREAD + return true; + } + + return false; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_group.h b/paddle/cinn/hlir/dialect/operator/transforms/op_group.h new file mode 100644 index 00000000000000..87138df17be85b --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_group.h @@ -0,0 +1,195 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_node.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class OpGroup { + public: + explicit OpGroup(const std::shared_ptr& group) : group_(group) {} + + OpGroup(const OpGroup& other) = default; + + using Comparator = ir::Group::SharedGroupComparator; + using Hasher = ir::Group::SharedGroupHasher; + + class OpGroupListIterator { + public: + OpGroupListIterator(std::unordered_set, + Hasher, + Comparator>::const_iterator it) + : iter_(it) {} + + OpGroupListIterator& operator++() { + ++iter_; + return *this; + } + + OpGroupListIterator operator++(int) { + OpGroupListIterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const OpGroupListIterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const OpGroupListIterator& other) const { + return !(*this == other); + } + + OpGroup operator*() const { return OpGroup(*iter_); } + + private: + std::unordered_set, Hasher, Comparator>:: + const_iterator iter_; + }; + + class ProducerOpGroupListView { + public: + explicit ProducerOpGroupListView(const std::weak_ptr& group) + : group_(group) {} + + ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete; + ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete; + + ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) = + delete; + + using const_iterator = OpGroupListIterator; + + size_t size() const { + CHECK(group_.lock()); + return group_.lock()->producer_groups().size(); + } + + const_iterator begin() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->producer_groups().begin()); + } + + const_iterator end() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->producer_groups().end()); + } + + private: + const std::weak_ptr group_; + }; + + class ConsumerOpGroupListView { + public: + explicit ConsumerOpGroupListView(const std::weak_ptr& group) + : group_(group) {} + + ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete; + ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete; + + ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) = + delete; + + using const_iterator = OpGroupListIterator; + + size_t size() const { + CHECK(group_.lock()); + return group_.lock()->consumer_groups().size(); + } + + const_iterator begin() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->consumer_groups().begin()); + } + + const_iterator end() const { + CHECK(group_.lock()); + return const_iterator(group_.lock()->consumer_groups().end()); + } + + private: + const std::weak_ptr group_; + }; + + const std::string& group_id() const { return group_.lock()->group_id; } + + OpPatternKind kind() const { return group_.lock()->kind(); } + + // The WalkOpNodes function is used to traverse the op_nodes in the group and + // execute the VisitOpNode function for each OpNode. This function is + // equivalent to for loop for op_nodes in graph. + // + // In order to avoid unnecessary memory copies, we use WalkOpNodes function + // instead of providing a function to get all op_nodes directly. + // + // Example: Get the all Reduction op_nodes in the group. + // OpGroup group = ...; + // std::set reduce_ op_set; + // // The lambda funtion of VisitOpNode to get reduction op_nodes. + // auto get_reduce_op = [&reduce_op_set](const cinn::dialect::ir::OpNode& + // op){ + // if (op.kind() == OpPatternKind::kReduction) { + // reduce_op_set.insert(op); + // } + // }; + // group.WalkOpNodes(get_reduce_op); + void WalkOpNodes( + const std::function& VisitOpNode) const { + group_.lock()->WalkNodes( + [&](::pir::Operation* node) { VisitOpNode(OpNode(node)); }); + } + + ProducerOpGroupListView producers() const { + return ProducerOpGroupListView(group_); + } + + ConsumerOpGroupListView consumers() const { + return ConsumerOpGroupListView(group_); + } + + std::shared_ptr GetGroup() const { return group_.lock(); } + + bool operator==(const OpGroup& other) const { + return group_.lock().get() == other.group_.lock().get(); + } + + bool operator<(const OpGroup& other) const { + return group_.lock().get() < other.group_.lock().get(); + } + + private: + const std::weak_ptr group_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::dialect::ir::OpGroup& obj) const { + return std::hash()(reinterpret_cast(obj.GetGroup().get())); + } +}; + +} // namespace std diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_node.h b/paddle/cinn/hlir/dialect/operator/transforms/op_node.h new file mode 100644 index 00000000000000..8579d11b19bb96 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_node.h @@ -0,0 +1,168 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/pir/core/operation.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class OpNode { + public: + explicit OpNode(::pir::Operation* node) + : node_(node), input_tensors_(node), output_tensors_(node) {} + + OpPatternKind kind() const { + auto kind = GetOpKind(node_->name()); + if (kind == kBroadcast) { + // As binary op was defined as broadcast, actually it should be + // element-wise. + if (node_->name() != "broadcast_to") { + return kElementWise; + } + } + return kind; + } + + class TensorListIterator { + public: + TensorListIterator(size_t index, ::pir::Operation* op) + : iter_(index), op_(op) {} + + TensorListIterator& operator++() { + ++iter_; + return *this; + } + + TensorListIterator operator++(int) { + TensorListIterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const TensorListIterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const TensorListIterator& other) const { + return !(*this == other); + } + + TensorNode operator*() const { + return TensorNode(op_->operand_source(iter_)); + } + + private: + size_t iter_; + ::pir::Operation* op_; + }; + + using const_iterator = TensorListIterator; + + class InputTensorListView { + public: + explicit InputTensorListView(::pir::Operation* op) : op_(op) {} + + // InputTensorListView(const InputTensorListView& other) = delete; + // InputTensorListView(InputTensorListView&& other) = delete; + + // InputTensorListView& operator=(const InputTensorListView& other) = + // delete; + + size_t size() const { return op_->num_operands(); } + + TensorNode operator[](size_t index) const { + return TensorNode(op_->operand_source(index)); + } + + const_iterator begin() const { return const_iterator(0, op_); } + + const_iterator end() const { + return const_iterator(op_->num_operands(), op_); + } + + private: + ::pir::Operation* op_; + }; + + class OutputTensorListView { + public: + explicit OutputTensorListView(::pir::Operation* op) : op_(op) {} + + // OutputTensorListView(const OutputTensorListView& other) = delete; + // OutputTensorListView(OutputTensorListView&& other) = delete; + + // OutputTensorListView& operator=(const OutputTensorListView& other) = + // delete; + + size_t size() const { return op_->num_results(); } + + TensorNode operator[](size_t index) const { + return TensorNode(op_->result(index)); + } + + const_iterator begin() const { return const_iterator(0, op_); } + + const_iterator end() const { + return const_iterator(op_->num_results(), op_); + } + + private: + ::pir::Operation* op_; + }; + + bool operator==(const OpNode& other) const { return node_ == other.node_; } + + bool operator<(const OpNode& other) const { return node_ < other.node_; } + + const InputTensorListView& inputs() const { return input_tensors_; } + + const OutputTensorListView& outputs() const { return output_tensors_; } + + template + const T& GetAttr(const std::string& attr_name) const { + auto attr = + paddle::dialect::GetAttributeData(node_->attributes().at(attr_name)); + return PADDLE_GET_CONST(T, attr); + } + + private: + friend struct std::hash; + + ::pir::Operation* node_; + + const InputTensorListView input_tensors_; + const OutputTensorListView output_tensors_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn + +namespace std { + +template <> +struct hash { + size_t operator()(const cinn::dialect::ir::OpNode& obj) const { + return std::hash()(reinterpret_cast(obj.node_)); + } +}; + +} // namespace std diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc new file mode 100644 index 00000000000000..3039d81ff83a35 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.cc @@ -0,0 +1,528 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" + +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" + +namespace cinn { +namespace dialect { +namespace ir { + +std::unordered_map OpKindMap = { + {"pd_op.add", OpPatternKind::kElementWise}, + {"pd_op.subtract", OpPatternKind::kElementWise}, + {"pd_op.multiply", OpPatternKind::kElementWise}, + {"pd_op.divide", OpPatternKind::kElementWise}, + {"pd_op.sqrt", OpPatternKind::kElementWise}, + {"pd_op.full", OpPatternKind::kElementWise}, + {"pd_op.relu", OpPatternKind::kElementWise}, + {"pd_op.exp", OpPatternKind::kElementWise}, + {"pd_op.sum", OpPatternKind::kReduction}, + {"cinn_op.reduce_sum", OpPatternKind::kReduction}, + {"cinn_op.reduce_max", OpPatternKind::kReduction}, + {"cinn_op.broadcast", OpPatternKind::kBroadcast}, +}; + +OpPatternKind GetOpKind(const std::string& op_name) { + auto found_it = OpKindMap.find(op_name); + if (found_it == OpKindMap.end()) { + throw std::runtime_error("not support op yet in op kind map"); + } + + return found_it->second; +} + +phi::DDim GetFirstInputShape(const ::pir::Operation* op) { + auto in = op->operand_source(0); + + return in.type().dyn_cast().dims(); +} + +phi::DDim GetValueShape(const ::pir::Value value) { + return value.type().dyn_cast().dims(); +} + +bool WithoutLastDimInReduce(const std::vector& inshape, + const std::vector& axes) { + // if last axis is in reduce. + if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || + std::find(axes.begin(), axes.end(), -1) != axes.end()) { + return false; + } + + int64_t sum_last_axes = 1; + for (size_t idx = axes.back() + 1; idx < inshape.size(); ++idx) { + sum_last_axes *= inshape[idx]; + } + + if (sum_last_axes > 1) { + return true; + } else { + return false; + } +} + +int GetSharedSize(::pir::Operation* node) { + auto inshape = phi::vectorize(GetValueShape(node->result(0))); + + auto axes = GetVectorAttr(node, "axis"); + + if (WithoutLastDimInReduce(inshape, axes)) { + int lane = 1; + for (size_t idx = axes.back() + 1; idx < inshape.size(); ++idx) { + lane = inshape[idx]; + } + // int max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + // todo(phlrain): get gpu max threads + int max_num_threads = 2048; + if (lane > max_num_threads / 2) { + return 0; + } + int index = axes.size() - 1; + for (; index >= 0; --index) { + if (static_cast(index + 1) < axes.size() && + axes[index] != axes[index + 1] - 1) { + break; + } + lane *= inshape[axes[index]]; + if (lane > max_num_threads / 2) { + break; + } + } + // if lane > (max_num_threads / 2),the loop break from lane > + // max_num_threads / 2. + int axis = lane > (max_num_threads / 2) ? axes[index] : axes[index + 1]; + if (lane <= max_num_threads) { + return lane * sizeof(float); + } else { + int prefix = inshape[axis]; + int tail = lane / prefix; + for (int idx = max_num_threads / tail; + idx > ((max_num_threads / 2) / tail); + --idx) { + if (prefix % idx == 0) { + return idx * tail * sizeof(float); + } + } + int num = max_num_threads / tail; + return num * tail * sizeof(float); + } + } + return 0; +} + +using ConditionFunction = + std::function; + +// Op Fusion Pass which performs Ops fusion, Ops are fused +// "vertically", meaning producing Ops are fused into their consumers +// with the intent that the loops which compute their values will be fused in +// code generation. +class OpFusionPassHelper { + public: + explicit OpFusionPassHelper(const ::pir::Program& graph) { + // init fusion relation + InitFusionRelation(); + // filter node data, create group for each node + // auto nodes_inorder = std::get<0>(graph->topological_order()); + + for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { + auto node = *it; + local_ops_.insert(node); + } + + int index = 0; + for (auto it = graph.block()->begin(); it != graph.block()->end(); ++it) { + auto node = *it; + if (node) { + nodes_.push_back(node); + auto group = std::make_shared(); + // init group + group->nodes.push_back(node); + group->nodes_set.insert(node); + group->output_nodes.insert(node); + // input node + + for (size_t i = 0; i < node->num_operands(); ++i) { + auto input = + node->operand_source(i).dyn_cast().owner(); + if (input && (local_ops_.count(input))) { + group->input_nodes[input] = 1; + } + } + + // group type + group->op_pattern_kind = GetOpKind(node->name()); + // use current node as master node for schedule + group->master_nodes.insert(node); + + // get opration unique id + group->group_id = "id_" + std::to_string(index++); + fusion_groups_[node] = group; + } + } + // reverse node for output to input + std::reverse(nodes_.begin(), nodes_.end()); + } + + // return a vector of groups in topological order. + GroupList operator()(bool do_fusion = true) { + // do op fusion. + if (do_fusion) { + DoOpFusion(); + } + + // find all fusion group. + GroupList fusion_groups; + std::unordered_set groups_set; + for (auto node : nodes_) { + auto& group = fusion_groups_[node]; + if (!groups_set.count(group.get())) { + groups_set.insert(group.get()); + fusion_groups.push_back(group); + // reverse nodes order to producer->consumer. + std::reverse(group->nodes.begin(), group->nodes.end()); + } + } + + // producer consumer + for (auto& consumer : fusion_groups) { + for (auto& input_node : consumer->input_nodes) { + if (!local_ops_.count(input_node.first)) { + continue; + } + auto& producer = fusion_groups_[input_node.first]; + consumer->mut_producer_groups()->insert(producer); + producer->mut_consumer_groups()->insert(consumer); + } + } + + // init group depth. + for (auto& group : fusion_groups) { + for (const auto& consumer : group->consumer_groups()) { + // update depth. + group->depth = std::max(group->depth, consumer->depth + 1); + } + } + + // reverse to keep fusion group in order. + std::reverse(fusion_groups.begin(), fusion_groups.end()); + + return fusion_groups; + } + + private: + void DoOpFusion() { + for (auto consumer : nodes_) { + auto consumer_kind = GetOpKind(consumer->name()); + // kNonFusible op can't fuse any other op. + if (consumer_kind == kNonFusible) { + continue; + } + + // fusion op for consumer + auto consumer_fusion = fusion_groups_[consumer]; // + // check all linkin node + for (size_t i = 0; i < consumer->num_operands(); ++i) { + auto producer_data = consumer->operand_source(i); + + auto producer = producer_data.dyn_cast().owner(); + if (!local_ops_.count(producer)) { + continue; + } + + // if producer is fused. + if (consumer_fusion->nodes_set.count(producer)) { + // VLOG(3) << "Op " << producer->id() << " is fused."; + continue; + } + // if producer data is placeholder + if (!producer) { + continue; + } + // kNonFusible op can't fuse any other op. + auto producer_kind = GetOpKind(producer->name()); + if (producer_kind == kNonFusible) { + continue; + } + // VLOG(3) << "Producer Op: " << producer->id() + // << ", Op Pattern: " << producer_kind + // << " -> Consumer Op: " << consumer->id() + // << ", Op Pattern: " << consumer_kind; + bool can_fuse = true; + // checkout producer node outputs are all in fusion op + + // find all the op use by + size_t producer_data_used_num = 0; + for (auto it = producer_data.use_begin(); it != producer_data.use_end(); + ++it) { + auto consumer_node = it->owner(); + producer_data_used_num++; + // if fusion group can't find node, can't merge + if (consumer_fusion->nodes_set.find(consumer_node) == + consumer_fusion->nodes_set.end()) { + can_fuse = false; + break; + } + } + + if (!can_fuse || !CanFuse(producer, consumer)) continue; + // VLOG(3) << "Fuse Op " << producer->id() << " into Op " + // << consumer->id(); + + // fuse producer to fusion group + // TODO(phrain) : support id + // consumer_fusion->group_id = + // producer->id() + "_" + consumer_fusion->group_id; + + consumer_fusion->group_id = consumer_fusion->group_id; + consumer_fusion->nodes.push_back(producer); + consumer_fusion->nodes_set.insert(producer); + consumer_fusion->input_nodes.erase(producer); + consumer_fusion->op_pattern_kind = + static_cast(consumer_fusion->op_pattern_kind) > + static_cast(producer_kind) + ? consumer_fusion->op_pattern_kind + : producer_kind; + + if (producer_kind == kReduction) { + consumer_fusion->master_nodes.insert(producer); + } + + if (output_nodes_set_.count(producer)) { + // VLOG(3) << "Insert Global Output Node : " << producer->id(); + consumer_fusion->output_nodes.insert(producer); + } else if (producer_data_used_num > 1 && producer->num_operands() > 0 && + is_same_size(producer, consumer_fusion)) { + // producer is not a const value node. + consumer_fusion->internal_nodes.insert(producer); + } + + // fuse input node + + auto producer_fusion = fusion_groups_[producer]; + for (auto input_node : producer_fusion->input_nodes) { + if (consumer_fusion->input_nodes.count(input_node.first)) { + consumer_fusion->input_nodes[input_node.first] += input_node.second; + } else { + consumer_fusion->input_nodes.insert(input_node); + } + } + // update node group + fusion_groups_[producer] = consumer_fusion; + } + } + } + + void InitFusionRelation() { + // fusion relation. + // 1.kElementwise as producer + { + FusionRelation relation; + // producer -> consumer + relation.op_kind = {kElementWise, kBroadcast, kReduction, kInjective}; + // producer -> fusion + relation.fusion_op_kind = { + // horizontal or vertical relation(Elementwise + *Elementwise*). As + // has same output shape, can always fuse. + {kElementWise, always_fuse}, + // must be horizontal, as Elementwise + Broadcast is left to fusion + // merge pass. + {kBroadcast, + [](::pir::Operation* producer, const GroupPtr& consumer) -> bool { + // NOTE, producer and consumer NEVER be same size + if (is_same_size(producer, consumer)) { + return true; + } + + // NOTE, original code is below, if produer is not output node, + // result always be true + // !helper->output_nodes_set_.count(producer); + return true; + }}, + // horizontal or vertical relation, check with same output shape with + // horizontal relation or with last + // successive dimension less than 1024 for gpu. + {kReduction, horizontal_or_vertical_reduce_relation}, + // can be horizontal or can compute inline, check with same output + // shape or can compute inline. + {kInjective, horizontal_or_can_inline}, + // must be horizontal, check with same output shape. + {kOutFusible, is_same_shape}}; + fusion_relation_map_[kElementWise] = std::move(relation); + } + // 2.kBroadcast as producer + { + FusionRelation relation; + // producer -> consumer + relation.op_kind = {kElementWise, kReduction, kInjective}; + // producer -> fusion + relation.fusion_op_kind = { + // horizontal or vertical relation(Broadcast + *Elementwise*), check + // with same output shape. + {kElementWise, is_same_size}, + // must be horizontal, as Broadcast + Broadcast is not allowed. + {kBroadcast, is_same_size}, + // horizontal or vertical relation(Broadcast + Reduce). + {kReduction, horizontal_or_vertical_reduce_relation}, + // can be horizontal or can compute inline, check with same output + // shape or just one consumer. + {kInjective, horizontal_or_can_inline}, + // must be horizontal, check with same output shape. + {kOutFusible, is_same_shape}}; + fusion_relation_map_[kBroadcast] = std::move(relation); + } + // 3.kReduction as producer + { + FusionRelation relation; + // producer -> consumer + relation.op_kind = {kElementWise, kBroadcast}; + // producer -> fusion + relation.fusion_op_kind = { + // horizontal or vertical relation(Reduce + Elementwise*), check + // without last dimension in reduce. + {kElementWise, is_same_size}, + // must be horizontal relation, check with same output shape and + // without last dimension in reduce. + {kBroadcast, reduce_fuse_broadcast}, + // must be horizontal relation and with same reduce attr. + {kReduction, reduce_fuse_reduce}, + // no_fuse + {kInjective, no_fuse}, + // can't fuse. + {kOutFusible, no_fuse}}; + fusion_relation_map_[kReduction] = std::move(relation); + } + // 4.kInjective + { + FusionRelation relation; + // producer -> consumer + relation.op_kind = {kElementWise, kInjective}; + // producer -> fusion + relation.fusion_op_kind = { + // can be horizontal or vertical(Injective + Elementwise), check with + // same output shape. + {kElementWise, is_same_size}, + // must be horizontal relation, check with same output shape. + {kBroadcast, horizontal_with_same_size}, + // left to fusion merge pass. + {kReduction, no_fuse}, + // must be horizontal relation, check with same output shape. + {kInjective, horizontal_or_can_inline}, + // can't fuse. + {kOutFusible, no_fuse}, + }; + fusion_relation_map_[kInjective] = std::move(relation); + } + // 5.kOutFusible + { + FusionRelation relation; + // producer -> consumer + relation.op_kind = {kElementWise, kBroadcast}; + // producer -> fusion + relation.fusion_op_kind = { + // horizontal or vertical relation, check has same shape. + {kElementWise, is_same_shape}, + // it must be horizontal relation, check has same shape. + {kBroadcast, is_same_shape}, + // can't fuse. + {kReduction, no_fuse}, + // must be horizontal relation, check has same shape. + {kInjective, is_same_shape}, + // can't fuse. + {kOutFusible, no_fuse}, + }; + fusion_relation_map_[kOutFusible] = std::move(relation); + } + } + + bool CanFuse(::pir::Operation* producer, const ::pir::Operation* consumer) { + auto& relation = fusion_relation_map_[GetOpKind(producer->name())]; + // first step: check producer can be fused into consumer + if (relation.op_kind.count(GetOpKind(consumer->name()))) { + auto& consumer_group = fusion_groups_[consumer]; + // second step: check producer can be fused into consumer group + VLOG(3) << "Call ConditionFunction, Producer Op Pattern : " + << GetOpKind(producer->name()) << " , Consumer Group Pattern : " + << consumer_group->op_pattern_kind; + return relation.fusion_op_kind[consumer_group->op_pattern_kind]( + producer, fusion_groups_[consumer]); + } + + return false; + } + std::vector<::pir::Operation*> nodes_; + std::unordered_map fusion_groups_; + std::unordered_set output_nodes_set_; + + std::vector> groups_; + + std::unordered_set local_ops_; + + struct FusionRelation { + // producer -> consumer + std::unordered_set op_kind = {}; + // producer -> fusion sonsumer + std::unordered_map fusion_op_kind = {}; + }; + std::unordered_map fusion_relation_map_; +}; + +GroupList OpFusionPassInternal(const ::pir::Program& program) { + VLOG(3) << "OpFusionPass...!"; + auto op_fusion_helper = OpFusionPassHelper(program); + auto res = op_fusion_helper(); + + for (size_t i = 0; i < res.size(); ++i) { + auto group = res[i]; + + for (size_t j = 0; j < group->nodes.size(); ++j) { + } + } + + // for (auto& group : graph->fusion_groups) { + // VLOG(3) << "Group Id : " << group->group_id; + // for (const auto& producer : group->producer_groups()) { + // VLOG(3) << " producer group -> " << producer->group_id; + // } + // for (const auto& consumer : group->consumer_groups()) { + // VLOG(3) << " consumer group -> " << consumer->group_id; + // } + // } + VLOG(3) << "OpFusionPass Finish...!"; + + return res; +} + +// void BuildNonFusedGroupsPassInternal(framework::Graph* graph) { +// auto op_fusion_helper = OpFusionPassHelper(graph); +// VLOG(3) << "Apply OpFusionPass to generate initial non-fusion groups"; +// graph->fusion_groups = op_fusion_helper(false); +// } + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h new file mode 100644 index 00000000000000..c784140c1cf363 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h" +#include "paddle/pir/core/program.h" + +namespace cinn { +namespace dialect { +namespace ir { + +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + +GroupList OpFusionPassInternal(const ::pir::Program& program); + +GroupList GeneralFusionMergePassInternal(const ::pir::Program* graph, + const GroupList& group_list); + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h new file mode 100644 index 00000000000000..1ba6ba85b51588 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_util.h @@ -0,0 +1,587 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace cinn { +namespace dialect { +namespace ir { + +enum OpPatternKind { + // The relation between input tensor index and output tensor index is + // one-to-one correspondence. + // for example :code:`out[i, j] = input[i, j] + 1`. + // Note that the axis need to be in order. + kElementWise = 0, + // The relation between input tensor index and output tensor index is + // one-to-many correspondence. + // for example :code:`out[i, j, k] = input[i, j]`. + // Note that the axis need to be in order. + kBroadcast = 1, + // Injective operator, we can always injectively map a output axis to a input + // axis. + // for example :code:`out[i, j] = input[j, i]`. + kInjective = 2, + // The relation between input tensor index and output tensor index is + // many-to-one correspondence. + // for example :code:`out[i, j] = sum(input[i, j, k]) along k`. + kReduction = 3, + // Complex operation, can still fuse one-to-one operations into its output. + kOutFusible = 4, + // Operation that cannot fuse anything. + kNonFusible = 8 +}; + +OpPatternKind GetOpKind(const std::string& op_name); + +template +std::vector GetVectorAttr(const ::pir::Operation* op, + const std::string& name) { + auto& attr_map = op->attributes(); + PADDLE_ENFORCE( + attr_map.count(name), + phi::errors::PreconditionNotMet( + "attr [%s] MUST in attribute map for [%s] op", name, op->name())); + auto& val = attr_map.at(name); + + PADDLE_ENFORCE(val.isa<::pir::ArrayAttribute>(), + phi::errors::PreconditionNotMet( + "axis Type MUST ArrayAttribute for [%s] op", op->name())); + auto array_list = val.dyn_cast<::pir::ArrayAttribute>().AsVector(); + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ(array_list[0].isa<::pir::Int64Attribute>(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be ir::Int64Attribute")); + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back(array_list[i].dyn_cast<::pir::Int64Attribute>().data()); + } + } + return vec_res; +} + +struct Group { + Group() = default; + + // distance to last group. + int depth{0}; + int max_depth{0}; + int min_depth{INT_MAX}; + // group id, consisted of node's id. + std::string group_id{""}; + // global unique id. + std::string unique_id{"uniq"}; + // node in this group + std::vector<::pir::Operation*> nodes; + std::unordered_set<::pir::Operation*> nodes_set; + // input nodes of the group. + std::unordered_map<::pir::Operation*, int> input_nodes; + // output nodes of the group. + std::unordered_set<::pir::Operation*> output_nodes; + // op pattern kind. + OpPatternKind op_pattern_kind{kElementWise}; + // internal node, the output is used by multi-node. + // internal node can't use compute inline, should use buffer. + std::unordered_set<::pir::Operation*> internal_nodes; + // master node for schedule + std::unordered_set<::pir::Operation*> master_nodes; + + // fused sub-groups, used for fusion merge pass + std::vector> fused_sub_groups; + // if as sub-group, used for belong groups. + std::unordered_set> belong_groups; + + // for op lowering. + std::vector input_names; + std::vector output_names; + + struct SharedGroupHasher { + size_t operator()(const std::shared_ptr& group) const noexcept { + return std::hash()(reinterpret_cast(group.get())); + } + }; + struct SharedGroupComparator { + bool operator()(const std::shared_ptr& first, + const std::shared_ptr& second) const noexcept { + return first.get() == second.get(); + } + }; + + std::vector<::pir::Operation*> CollectNodes() { + if (fused_sub_groups.size()) { + std::vector<::pir::Operation*> tmp_nodes; + for (auto& group : fused_sub_groups) { + tmp_nodes.insert( + tmp_nodes.end(), group->nodes.begin(), group->nodes.end()); + } + return tmp_nodes; + } else { + return nodes; + } + } + + void WalkNodes( + const std::function& VisitNode) const { + if (fused_sub_groups.size()) { + for (auto& group : fused_sub_groups) { + for (const auto& node : group->nodes) { + VisitNode(node); + } + } + } else { + for (const auto& node : nodes) { + VisitNode(node); + } + } + } + + std::unordered_set<::pir::Operation*> NodeSet() { + std::unordered_set<::pir::Operation*> node_set; + for (auto node : CollectNodes()) { + node_set.insert(node); + } + return node_set; + } + + // TODO(phlrain) : impliment GetInputNodeDatas GetOutputNodeDatas func + // std::unordered_set<::pir::Value> GetInputNodeDatas() { return {}; } + // std::unordered_set<::pir::Value> GetOutputNodeDatas() { return {}; } + + std::string GetFuncName() { return "fn_" + group_id + unique_id; } + + public: + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + producer_groups() const { + return producer_groups_; + } + + const std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>& + consumer_groups() const { + return consumer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_producer_groups() { + return &producer_groups_; + } + + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator>* + mut_consumer_groups() { + return &consumer_groups_; + } + + OpPatternKind kind() const { return op_pattern_kind; } + + private: + // input groups + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + producer_groups_; + // output grous + std::unordered_set, + SharedGroupHasher, + SharedGroupComparator> + consumer_groups_; +}; + +phi::DDim GetFirstInputShape(const ::pir::Operation* op); + +phi::DDim GetValueShape(const ::pir::Value value); + +bool WithoutLastDimInReduce(const std::vector& inshape, + const std::vector& axes); + +int GetSharedSize(::pir::Operation* node); + +inline bool always_fuse(::pir::Operation* producer, + const std::shared_ptr& consumer) { + return true; +} + +inline bool no_fuse(::pir::Operation* producer, + const std::shared_ptr& consumer) { + return false; +} + +inline bool is_same_shape(::pir::Operation* producer, + const std::shared_ptr& consumer) { + auto master_node = consumer->master_nodes.begin(); + return GetValueShape(producer->result(0)) == + GetValueShape((*master_node)->result(0)); +} + +inline bool is_same_size(::pir::Operation* producer, + const std::shared_ptr& consumer) { + auto master_node = consumer->master_nodes.begin(); + auto producer_shape = GetValueShape(producer->result(0)); + auto consumer_shape = GetValueShape((*master_node)->result(0)); + if (producer_shape == consumer_shape) { + return true; + } + auto psize = phi::product(producer_shape); + auto csize = phi::product(consumer_shape); + return psize == csize; +} + +inline bool without_last_dimension_in_reduce( + ::pir::Operation* producer, const std::shared_ptr& consumer) { + auto in_shape = phi::vectorize(GetFirstInputShape(producer)); + auto reduce_axes = GetVectorAttr(producer, "axis"); + return WithoutLastDimInReduce(in_shape, reduce_axes); +} + +inline bool reduce_fuse_reduce(::pir::Operation* producer, + const std::shared_ptr& consumer) { + ::pir::Operation* reducer = NULL; + for (auto* master : consumer->master_nodes) { + if (GetOpKind(master->name()) == kReduction) { + reducer = master; + break; + } + } + // check reduce has same input shape and output shape + auto producer_input_shape = + phi::vectorize(GetValueShape(producer->operand_source(0))); + auto producer_output_shape = + phi::vectorize(GetValueShape(producer->result(0))); + + auto reducer_input_shape = + phi::vectorize(GetValueShape(reducer->operand_source(0))); + auto reducer_output_shape = + phi::vectorize(GetValueShape(reducer->result(0))); + + auto producer_reduce_dim = GetVectorAttr(producer, "axis"); + auto reducer_reduce_dim = GetVectorAttr(reducer, "axis"); + + for (auto& dim : producer_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim < 0) { + dim += producer_input_shape.size(); + } + } + + for (auto& dim : reducer_reduce_dim) { + // if dim = -1, set as shape.size() - 1 + if (dim < 0) { + dim += reducer_input_shape.size(); + } + } + + if (producer_output_shape == reducer_output_shape && + producer_reduce_dim == reducer_reduce_dim) { + bool input_shape_same = producer_input_shape == reducer_input_shape; + bool without_last_dim = + WithoutLastDimInReduce(producer_input_shape, producer_reduce_dim) && + WithoutLastDimInReduce(reducer_input_shape, reducer_reduce_dim); + // check shape is same + if (input_shape_same || without_last_dim) { + auto shared_size = GetSharedSize(producer); + for (auto* master : consumer->master_nodes) { + if (GetOpKind(master->name()) == kReduction) { + shared_size += GetSharedSize(master); + } + } + + constexpr int MAX_AVAILABLE_SHREAD = 32 * 1024; + if (shared_size > MAX_AVAILABLE_SHREAD) { + return false; + } + return true; + } + } + + return false; +} + +inline bool is_horizontal_relation(::pir::Operation* producer, + const std::shared_ptr& consumer) { + auto check_depency = [&](::pir::Operation* node) { + std::queue<::pir::Operation*> candidates; + std::unordered_set<::pir::Operation*> visited_set; + candidates.push(node); + + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + // visit all producer node + for (size_t i = 0; i < candidate->num_operands(); ++i) { + auto tmp_node = + candidate->operand_source(i).dyn_cast().owner(); + // check depency. + if (producer == tmp_node) { + return true; + } + // check node is in region. + if (!consumer->nodes_set.count(tmp_node)) { + continue; + } + // recored visited node. + if (!visited_set.count(tmp_node)) { + visited_set.insert(tmp_node); + candidates.push(tmp_node); + } + } + } + + return false; + }; + + for (auto node : consumer->nodes_set) { + if (GetOpKind(node->name()) != consumer->op_pattern_kind) { + continue; + } + if (check_depency(node)) { + return false; + } + } + + return true; +} + +inline bool horizontal_or_vertical_reduce_relation( + ::pir::Operation* producer, const std::shared_ptr& consumer) { + // check is same shape with horizontal relation. + if (is_same_size(producer, consumer)) { + return true; + } + + // reducer node in fusion op. + ::pir::Operation* reducer = NULL; + for (auto* master : consumer->master_nodes) { + if (GetOpKind(master->name()) == kReduction) { + reducer = master; + break; + } + } + + // check producer has same shape with reducer node. + auto reduce_shape = phi::vectorize(GetFirstInputShape(reducer)); + auto reduce_axes = GetVectorAttr(reducer, "axis"); + + for (auto& axis : reduce_axes) { + // if axis = -1, set as shape.size() - 1 + if (axis < 0) { + axis += reduce_shape.size(); + } + } + + auto node_shape = phi::vectorize(GetFirstInputShape(producer)); + auto node_size = std::accumulate( + node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto reduce_size = std::accumulate( + reduce_shape.begin(), reduce_shape.end(), 1, std::multiplies()); + + // is not same size with reduce size. + if (node_size != reduce_size) { + return false; + } + // check without last axis in reduce. + if (WithoutLastDimInReduce(reduce_shape, reduce_axes)) { + return false; + } + + int succesive_reduce_dimension = reduce_shape.at(reduce_axes.back()); + for (int idx = reduce_axes.size() - 2; idx >= 0; --idx) { + if (reduce_axes[idx] == reduce_axes[idx + 1] - 1) { + succesive_reduce_dimension *= reduce_shape[reduce_axes[idx]]; + continue; + } + break; + } + + // helper->target_ == common::DefaultNVGPUTarget() + // succesive_reduce_dimension <= helper->target_.max_num_threads() + // TODO(phlrain): support is_gpu_target and max_thread + bool is_gpu_target = true; + int max_thread = 32 * 1024; + return is_gpu_target + ? (succesive_reduce_dimension <= max_thread ? true : false) + : true; +} + +inline bool horizontal_or_can_inline(::pir::Operation* producer, + const std::shared_ptr& consumer) { + // horizontal relation. + if (is_horizontal_relation(producer, consumer)) { + if (is_same_size(producer, consumer)) { + return true; + } else { + // if do broadcast, check can compute inline. + // return helper->output_nodes_set_.count(producer) == 0; + // TODO(phlrain): support output node set check + return false; + } + } + // vertical relation: 1.can compute inline + // if (helper->GetNodeData(producer)->outlinks().size() == 1 && + // helper->output_nodes_set_.count(producer) == 0) { + // return true; + // } + + // link to same node. + // auto& out_links = helper->GetNodeData(producer)->outlinks(); + // for (auto link : out_links) { + // if ((*out_links.begin())->sink() != link->sink()) { + // return false; + // } + // } + + // return helper->output_nodes_set_.count(producer) == 0; + + return false; +} + +inline bool horizontal_with_same_size(::pir::Operation* producer, + const std::shared_ptr& consumer) { + return is_horizontal_relation(producer, consumer) && + is_same_size(producer, consumer); +} + +inline bool reduce_fuse_broadcast(::pir::Operation* producer, + const std::shared_ptr& consumer) { + if (is_horizontal_relation(producer, consumer)) { + if (is_same_size(producer, consumer)) { + return true; + } + return false; + } + + // if (helper->target_ != common::DefaultNVGPUTarget()) { + // return true; + // } + + auto rinput_shape = phi::vectorize(GetFirstInputShape(producer)); + auto reduce_axes = GetVectorAttr(producer, "axis"); + auto keep_dim = producer->attributes() + .at("keep_dim") + .dyn_cast<::pir::BoolAttribute>() + .data(); + for (auto& axis : reduce_axes) { + if (axis < 0) { + axis += rinput_shape.size(); + } + } + + int reduce_size = rinput_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= rinput_shape[idx - 1]; + } + + // if (reduce_size > helper->target_.max_num_threads()) { + // return false; + // } + + auto routput_shape = + phi::vectorize(GetValueShape(producer->result(0))); + auto find_reducer = + [&](::pir::Operation* node, + ::pir::Operation* reducer, + const std::unordered_set<::pir::Operation*>& nodes_set) { + std::queue<::pir::Operation*> candidates; + candidates.push(node); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (size_t i = 0; i < candidate->num_operands(); ++i) { + auto producer = + candidate->operand_source(i).dyn_cast().owner(); + if (producer == reducer) { + return true; + } + + if (nodes_set.count(producer)) { + candidates.push(producer); + } + } + } + + return false; + }; + + for (auto node : consumer->nodes_set) { + if (GetOpKind(node->name()) != kBroadcast) { + continue; + } + + if (!find_reducer(node, producer, consumer->nodes_set)) { + continue; + } + + auto broadcast_shape = GetVectorAttr(node, "out_shape"); + auto broadcast_axes = GetVectorAttr(node, "broadcast_axes"); + + for (auto& axis : broadcast_axes) { + if (axis < 0) { + axis += broadcast_shape.size(); + } + } + + if (rinput_shape != broadcast_shape) { + return false; + } + // if keep dim = true. + if (keep_dim) { + continue; + } else { + // if routput_shape = [1] + if (routput_shape.size() == 1 && routput_shape[0] == 1) { + continue; + } + // check [reduce_axes, axes] = {0, 1, 2, 3, 4, 5, 6, ...} + for (size_t idx = 0; idx < rinput_shape.size(); ++idx) { + // note: !x ^ y == (!x) ^ y == !(x ^ y) + if ((std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) != + broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == + reduce_axes.end()) { + return false; + } + } + continue; + } + return false; + } + return true; +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.cc b/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.cc new file mode 100644 index 00000000000000..0688b513f4497c --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h" + +#include "paddle/cinn/hlir/dialect/operator/transforms/op_node.h" + +namespace cinn { +namespace dialect { +namespace ir { + +OpNode TensorNode::producer() const { + return OpNode(node_data_.dyn_cast().owner()); +} + +OpNode TensorNode::ConsumerOpListView::Iterator::operator*() const { + return OpNode(iter_.owner()); +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h b/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h new file mode 100644 index 00000000000000..c48e476ec2a8d0 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/tensor_node.h @@ -0,0 +1,102 @@ +// Copyright (c) 2023 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class OpNode; + +class TensorNode final { + public: + TensorNode(::pir::Value value) : node_data_(value), consumers_(value) {} + + // Get the shape of tensor. + const phi::DDim& shape() const { + return node_data_.type() + .dyn_cast() + .dims(); + } + + // Input data has no producer. + bool HasProducer() const { return consumers_.size() != 0; } + + OpNode producer() const; + + class ConsumerOpListView { + public: + explicit ConsumerOpListView(pir::Value data) : node_data_(data) {} + + ConsumerOpListView(const ConsumerOpListView& other) = delete; + ConsumerOpListView(ConsumerOpListView&& other) = delete; + + ConsumerOpListView& operator=(const ConsumerOpListView& other) = delete; + + using UseIterator = ::pir::ValueUseIterator<::pir::OpOperand>; + class Iterator { + public: + explicit Iterator(UseIterator it) : iter_(it) {} + + Iterator& operator++() { + ++iter_; + return *this; + } + + Iterator operator++(int) { + Iterator tmp = *this; + ++iter_; + return tmp; + } + + bool operator==(const Iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const Iterator& other) const { return !(*this == other); } + + OpNode operator*() const; + + private: + UseIterator iter_; + }; + + size_t size() const { return node_data_.use_count(); } + + Iterator begin() const { return Iterator(node_data_.use_begin()); } + + Iterator end() const { return Iterator(node_data_.use_end()); } + + private: + ::pir::Value node_data_; + }; + + const ConsumerOpListView& consumers() const { return consumers_; } + + private: + ::pir::Value node_data_; + + const ConsumerOpListView consumers_; +}; + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt index 6023117faee098..c85931ad954cf3 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt @@ -1,4 +1,10 @@ if(NOT CINN_ONLY) - cinn_cc_library(cinn_runtime_dialect SRCS runtime_dialect.cc jit_kernel_op.cc - DEPS pir_core) + cinn_cc_library( + cinn_runtime_dialect + SRCS + runtime_dialect.cc + jit_kernel_op.cc + DEPS + cinn_op_dialect + pir_core) endif() diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc index ed3d4a4045c595..56f598b55bf525 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc @@ -14,6 +14,8 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/framework/new_ir_compiler.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/enforce.h" @@ -22,20 +24,22 @@ namespace dialect { const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName}; -void JitKernelOp::Verify() { +void JitKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp."; auto& attributes = this->attributes(); - IR_ENFORCE(attributes.count(kAttrName) > 0 && - attributes.at(kAttrName).isa<::pir::PointerAttribute>(), - "Type of attribute: instruction is not right."); + IR_ENFORCE( + attributes.count(kAttrName) > 0 && + attributes.at(kAttrName).isa(), + "Type of attribute: instruction is not right."); } -hlir::framework::Instruction* JitKernelOp::instruction() { - void* ptr = - attributes().at(kAttrName).dyn_cast<::pir::PointerAttribute>().data(); - return reinterpret_cast(ptr); +const hlir::framework::newir::CUDAJITInfo& JitKernelOp::cuda_jit_info() { + return attributes() + .at(kAttrName) + .dyn_cast() + .data(); } } // namespace dialect diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h index f410e4d46c021a..0078d0d3b172d4 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h @@ -14,16 +14,11 @@ #pragma once +#include "paddle/cinn/hlir/framework/new_ir/utils.h" #include "paddle/pir/core/op_base.h" namespace cinn { -namespace hlir { -namespace framework { -class Instruction; -} // namespace framework -} // namespace hlir - namespace dialect { /* @@ -46,12 +41,12 @@ class JitKernelOp : public ::pir::Op { static const char* name() { return "cinn_runtime.jit_kernel"; } // TODO(Aurelius84): Think deeply what should contains static constexpr uint32_t attributes_num = 1; - static constexpr char* kAttrName = "instruction"; + static constexpr char* kAttrName = "jit_info"; static const char* attributes_name[attributes_num]; - hlir::framework::Instruction* instruction(); + const hlir::framework::newir::CUDAJITInfo& cuda_jit_info(); - void Verify(); + void VerifySig(); }; } // namespace dialect diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index 54da1e2b7dc904..1aa6817164a1cf 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -29,8 +29,6 @@ gather_srcs( if(NOT CINN_ONLY) cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi pd_op_dialect) - cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi - cinn_op_dialect) endif() if(WITH_CUDA) diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.cc b/paddle/cinn/hlir/framework/convert_to_dialect.cc deleted file mode 100644 index f76b49a54555f9..00000000000000 --- a/paddle/cinn/hlir/framework/convert_to_dialect.cc +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" - -#include -#include - -#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/program.h" -#include "paddle/pir/core/builtin_attribute.h" -#include "paddle/pir/core/program.h" - -namespace cinn { -namespace hlir { -namespace framework { - -std::unique_ptr<::pir::Program> ConvertToRuntimeDialect( - const hlir::framework::Program& program) { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - ctx->GetOrRegisterDialect(); - auto ir_program = std::make_unique<::pir::Program>(ctx); - - std::string jit_op_name = dialect::JitKernelOp::name(); - ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); - - auto& instrs = program.GetRunInstructions(); - for (auto& instr : instrs) { - std::unordered_map op_attrs{ - {dialect::JitKernelOp::kAttrName, - ::pir::PointerAttribute::get(ctx, instr.get())}, - }; - - ::pir::Operation* cinn_op = - ::pir::Operation::Create({}, op_attrs, {}, op_info); - ir_program->block()->push_back(cinn_op); - } - return std::move(ir_program); -} - -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc index ea76d939bc45b9..ac43f808e7303e 100644 --- a/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/new_ir/op_lowering_impl.cc @@ -15,6 +15,8 @@ #include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" #include + +#include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/hlir/framework/op_lowering_util.h" #include "paddle/cinn/hlir/op/external_api_registry.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" @@ -409,16 +411,17 @@ std::vector OpLowererImpl::DoOpLower( // 2.Do lower std::string lower_fn_name = CompatibleInfo::OpFuncName(*op); - std::vector funcs = lang::LowerVec(lower_fn_name, - tmp_stages, - *op_func_arg_tensors, - {}, - {}, - nullptr, - this->target_, - true); + ast_gen_ius::TensorGroup tensor_group = + ast_gen_ius::ConvertStageMapToTensorGroup(tmp_stages); + std::vector funcs = lang::LowerToAstVec( + lower_fn_name, *op_func_arg_tensors, {&tensor_group}, this->target_); VLOG(4) << "Lower op: " << lower_fn_name << ", get " << funcs.size() << " LoweredFunc:\n"; + if (VLOG_IS_ON(4)) { + for (auto fun : funcs) { + VLOG(4) << fun; + } + } op_func_arg_tensors->clear(); for (int idx = 0; idx < pack.size() - 1; ++idx) { diff --git a/paddle/cinn/hlir/framework/new_ir/utils.cc b/paddle/cinn/hlir/framework/new_ir/utils.cc index 3f938981390fbc..86cf0e187cc45a 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.cc +++ b/paddle/cinn/hlir/framework/new_ir/utils.cc @@ -20,7 +20,7 @@ namespace framework { namespace newir { const std::unordered_map CompatibleInfo::OP_NAMES = { - {"pd_op.full", "fill_constant"}}; + {"pd_op.full", "fill_constant"}, {"pd_op.add", "elementwise_add"}}; std::string CompatibleInfo::OpName(const ::pir::Operation& op) { std::string name = op.name(); diff --git a/paddle/cinn/hlir/framework/new_ir/utils.h b/paddle/cinn/hlir/framework/new_ir/utils.h index 953dc6672bc18f..755f11fcae2206 100644 --- a/paddle/cinn/hlir/framework/new_ir/utils.h +++ b/paddle/cinn/hlir/framework/new_ir/utils.h @@ -23,6 +23,12 @@ namespace hlir { namespace framework { namespace newir { +struct CUDAJITInfo { + void* fn_ptr; + std::vector block_dims; + std::vector grid_dims; +}; + struct CompatibleInfo { static constexpr char* kNamePrefix = "var_"; // TODO(Aurelius): Need add name mapping logic in REGISTER_CINN_OP diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.cc b/paddle/cinn/hlir/framework/new_ir_compiler.cc index 2a40531196da4d..fbc4c58a5ed9a9 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.cc +++ b/paddle/cinn/hlir/framework/new_ir_compiler.cc @@ -39,6 +39,44 @@ std::unique_ptr NewIRCompiler::Build() { return std::move(Build(groups)); } +std::vector NewIRCompiler::BuildCUDAJITInfo( + const std::vector& groups) { + std::vector vec_res; + + auto op_lowerer = CreateOpLowerer(target_); + + std::vector> lowered_funcs; + for (int i = 0; i < groups.size(); ++i) { + lowered_funcs.emplace_back(op_lowerer.Lower(groups[i])); + } + + for (auto&& lowered_func : lowered_funcs) { + ProcessFunction(lowered_func); + } + + compiler_ = backends::Compiler::Create(target_); + auto build_module = m_builder_.Build(); + compiler_->Build(build_module, ""); + + auto instructions = BuildInstructions(groups); + + auto fn_ptrs = compiler_->GetFnPtr(); + + for (int idx = 0; idx < groups.size(); ++idx) { + newir::CUDAJITInfo jit_info; + jit_info.fn_ptr = fn_ptrs[idx]; + + lowered_funcs[idx][0]->cuda_axis_info.CopyBlockDimsTo( + &(jit_info.block_dims)); + + lowered_funcs[idx][0]->cuda_axis_info.CopyGridDimsTo(&(jit_info.grid_dims)); + + vec_res.push_back(jit_info); + } + + return vec_res; +} + std::unique_ptr NewIRCompiler::Build( const std::vector& groups) { auto op_lowerer = CreateOpLowerer(target_); diff --git a/paddle/cinn/hlir/framework/new_ir_compiler.h b/paddle/cinn/hlir/framework/new_ir_compiler.h index 62c3d97a21a415..44d92ad1386bf0 100644 --- a/paddle/cinn/hlir/framework/new_ir_compiler.h +++ b/paddle/cinn/hlir/framework/new_ir_compiler.h @@ -40,6 +40,9 @@ class NewIRCompiler final { std::unique_ptr Build(); + std::vector BuildCUDAJITInfo( + const std::vector& groups); + std::unique_ptr Build(const std::vector& groups); private: diff --git a/paddle/fluid/distributed/auto_parallel/dist_attr.cc b/paddle/fluid/distributed/auto_parallel/dist_attr.cc index b7bb47b3b859e9..e6c31d06e21c2d 100644 --- a/paddle/fluid/distributed/auto_parallel/dist_attr.cc +++ b/paddle/fluid/distributed/auto_parallel/dist_attr.cc @@ -83,7 +83,7 @@ OperatorDistAttr& OperatorDistAttr::operator=( void OperatorDistAttr::initialize(const OpDesc* op) { if (op == nullptr) return; - for (std::string name : op->InputArgumentNames()) { + for (std::string const& name : op->InputArgumentNames()) { VarDesc* input = op->Block()->FindVarRecursive(name); VLOG(4) << "[OperatorDistAttr create input dist attr] " << name; if (input == nullptr || op->Type() == "create_py_reader") { @@ -92,7 +92,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) { input_dist_attrs_[name] = TensorDistAttr(get_tensor_shape(input)); } } - for (std::string name : op->OutputArgumentNames()) { + for (std::string const& name : op->OutputArgumentNames()) { VarDesc* output = op->Block()->FindVarRecursive(name); VLOG(4) << "[OperatorDistAttr create output dist attr] " << name; if (output == nullptr) { diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc index e8ef88c03032b9..09b4d6a2189b7a 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc @@ -57,7 +57,7 @@ std::unordered_map ShardingMergeForTensors( const bool merge_conflicts) { std::unordered_map axis_to_dim_map; std::unordered_map dim_to_axis_map; - int64_t merge_dim; + int64_t merge_dim = 0; for (auto& pair : tensor_axes_to_dim_pairs) { for (size_t i = 0; i < pair.second.size(); ++i) { diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc index f38a4b2f533b31..5cc6cf4e5e1376 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.cc @@ -34,7 +34,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferForward( input_specs_size)); auto x_shape = input_specs[0].shape(); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto x_dist_attr_src = input_specs[0].dist_attr(); std::vector x_dims_mapping_src = x_dist_attr_src.dims_mapping(); @@ -176,8 +176,8 @@ CrossEntropyWithSoftmaxSPMDRule::InferBackward( const std::vector& output_specs, const paddle::framework::AttributeMap& attrs) { // step0: verify input args based on cross_entropy_with_softmax logic - int64_t ninputs = input_specs.size(); - int64_t noutputs = output_specs.size(); + int64_t ninputs = static_cast(input_specs.size()); + int64_t noutputs = static_cast(output_specs.size()); PADDLE_ENFORCE_EQ( ninputs, 2, @@ -194,7 +194,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferBackward( // step1: build Einsum Notation std::vector x_shape = input_specs[0].shape(); - int64_t x_ndim = x_shape.size(); + int64_t x_ndim = static_cast(x_shape.size()); std::vector label_shape = input_specs[1].shape(); int axis = ExtractAttr("axis", attrs); @@ -205,7 +205,7 @@ CrossEntropyWithSoftmaxSPMDRule::InferBackward( // normalize axis if (axis < 0) { - axis = x_ndim + axis; + axis = static_cast(x_ndim + axis); } std::string alphabet = diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 72db9ee4e7550e..64dce7b4c6b116 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -32,23 +32,57 @@ PD_DECLARE_bool(use_stream_safe_cuda_allocator); namespace paddle { namespace distributed { +static std::mutex g_unfinished_xccl_task_events_mutex; +static std::list> + g_unfinished_xccl_task_events; + ProcessGroupCustom::XCCLTask::XCCLTask(const Place& place, int rank, CommType comm_type, bool sync_op, bool use_calc_stream) : TaskStream(rank, comm_type, sync_op, use_calc_stream), - task_place_(place) { - comm_event_.Init(place); + task_place_(place), + comm_event_(std::make_unique()) { + comm_event_->Init(task_place_); +} + +ProcessGroupCustom::XCCLTask::XCCLTask( + const std::vector& places, + int rank, + CommType CommType, + const std::vector& inputs) + : TaskStream(rank, inputs, CommType), + task_place_(places[0]), + comm_event_(std::make_unique()) { + comm_event_->Init(task_place_); } -ProcessGroupCustom::XCCLTask::~XCCLTask() = default; +ProcessGroupCustom::XCCLTask::~XCCLTask() { + if (!IsCompleted()) { + std::lock_guard lock(g_unfinished_xccl_task_events_mutex); + g_unfinished_xccl_task_events.push_back(std::move(comm_event_)); + } +} -bool ProcessGroupCustom::XCCLTask::IsCompleted() { return comm_event_.Query(); } +bool ProcessGroupCustom::XCCLTask::IsCompleted() { + return comm_event_->Query(); +} void ProcessGroupCustom::XCCLTask::UpdateWaitChain( const phi::DeviceContext& ctx) { - comm_event_.Record( + { + std::lock_guard lock(g_unfinished_xccl_task_events_mutex); + for (auto iter = g_unfinished_xccl_task_events.begin(); + iter != g_unfinished_xccl_task_events.end();) { + if ((*iter)->Query()) { + iter = g_unfinished_xccl_task_events.erase(iter); + } else { + iter++; + } + } + } + comm_event_->Record( reinterpret_cast(ctx).GetStream().get()); } @@ -62,7 +96,7 @@ bool ProcessGroupCustom::XCCLTask::Wait(std::chrono::milliseconds timeout) { const auto* calc_ctx = reinterpret_cast( platform::DeviceContextPool::Instance().Get(task_place_)); - calc_ctx->GetStream()->WaitEvent(&comm_event_); + calc_ctx->GetStream()->WaitEvent(comm_event_.get()); if (IsBlockCPUInWait()) { // If we use the work to do barrier, we should block cpu @@ -590,15 +624,6 @@ std::shared_ptr ProcessGroupCustom::CreateTask( places, rank, comm_type, inputs); } -ProcessGroupCustom::XCCLTask::XCCLTask( - const std::vector& places, - int rank, - CommType CommType, - const std::vector& inputs) - : TaskStream(rank, inputs, CommType), task_place_(places[0]) { - comm_event_.Init(places[0]); -} - // create XCCLManager cache for places_key void ProcessGroupCustom::CreateXCCLManagerCache( const std::string& places_key, const std::vector& places) { @@ -676,7 +701,7 @@ std::shared_ptr ProcessGroupCustom::Collective( { std::lock_guard lock(mutex_); if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { - CreateXCCLManagerCache(key, places); + CreateXCCLEnvCache(places[0], key); } } diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index c60d185c9e4808..13970b2e349a0e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -61,8 +61,8 @@ class ProcessGroupCustom final : public ProcessGroupWithStream { private: bool block_cpu_in_wait_{false}; - phi::event::Event comm_event_; // event on comm stream Place task_place_; + std::unique_ptr comm_event_; // event on comm stream }; public: diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index 5482c972fa4d1c..80609b9fd68289 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -571,7 +571,7 @@ void EagerReducer::InitializeGroups( tensor_locator.inside_group_index = inside_group_index++; variable_locators_[var_index] = tensor_locator; } - group.tensor_indices_ = std::move(tensor_indices_); + group.tensor_indices_ = tensor_indices_; groups_.emplace_back(std::move(group)); VLOG(3) << "The Group[" << group_index << "]:" << groups_.back(); @@ -985,6 +985,7 @@ void EagerReducer::ProcessUnusedDenseVars() { opts.reduce_op = ReduceOp::SUM; std::vector reduce_tensors = {global_used_vars_}; std::vector in_out; + in_out.reserve(reduce_tensors.size()); for (auto &t : reduce_tensors) { in_out.push_back(*std::dynamic_pointer_cast(t.impl())); } @@ -1081,6 +1082,7 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group, // all_reduce std::vector reduce_tensors = {group->dense_contents_}; std::vector in_out; + in_out.reserve(reduce_tensors.size()); for (auto &t : reduce_tensors) { in_out.push_back(*std::dynamic_pointer_cast(t.impl())); } @@ -1166,6 +1168,7 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, opts.reduce_op = ReduceOp::SUM; std::vector reduce_tensors = {rows_num_tensor}; std::vector in_out; + in_out.reserve(reduce_tensors.size()); for (auto &t : reduce_tensors) { in_out.push_back(*std::dynamic_pointer_cast(t.impl())); } @@ -1214,6 +1217,8 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, std::vector dst_rows_tensors = {dst_rows_tensor}; std::vector in; std::vector out; + in.reserve(src_rows_tensors.size()); + out.reserve(dst_rows_tensors.size()); for (auto &t : src_rows_tensors) { in.push_back(*std::dynamic_pointer_cast(t.impl())); } @@ -1245,6 +1250,8 @@ void EagerReducer::AllReduceSparse(EagerGroup *group, std::vector dst_value_tensors = {dst_value_tensor}; std::vector src_dense; std::vector dst_dense; + src_dense.reserve(src_value_tensors.size()); + dst_dense.reserve(dst_value_tensors.size()); for (auto &t : src_value_tensors) { src_dense.push_back( *std::dynamic_pointer_cast(t.impl())); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 6dc25faa80b4be..82a3514f2791f9 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -121,7 +121,7 @@ void Carrier::CopyParameters( const framework::ProgramDesc& program, const std::vector& inference_root_scope_vars) { std::map inference_root_scope_var_map; - for (auto var_name : inference_root_scope_vars) { + for (auto const& var_name : inference_root_scope_vars) { inference_root_scope_var_map.insert({var_name, 1}); } for (size_t i = 0; i < program.Size(); ++i) { @@ -392,6 +392,7 @@ void Carrier::CreateInterceptors( } } + cores.reserve(microbatch_scopes_.size()); for (framework::Scope* scope : microbatch_scopes_) { cores.push_back(std::make_shared( place_, task_node->program()->Block(0), scope, execution_config)); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 5bf026661d5146..7817b9bc0e9dfe 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -85,7 +85,7 @@ InterceptorMessage ComputeInterceptor::PrepareVarsMsg() { ready_msg.set_message_type(DATA_WITH_VARS); ready_msg.set_scope_idx(cur_scope_id_); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - for (auto iter : node_->vars_to_dtype()) { + for (auto const& iter : node_->vars_to_dtype()) { VarList* vars = ready_msg.add_vars_list(); const auto& var_name = iter.first; vars->set_name(var_name); diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.cc b/paddle/fluid/distributed/fleet_executor/dist_model.cc index 2a6da1b437a1b6..dc89c551fdc711 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.cc +++ b/paddle/fluid/distributed/fleet_executor/dist_model.cc @@ -47,7 +47,7 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data, const platform::Place &place) { VLOG(3) << "Loading data from DistModelTensor for " << input_data.name; framework::DDim dims = phi::make_ddim(input_data.shape); - void *input_tensor_ptr; + void *input_tensor_ptr = nullptr; if (input_data.dtype == DistModelDataType::INT64) { input_tensor_ptr = input_tensor->mutable_data(dims, place); } else if (input_data.dtype == DistModelDataType::FLOAT32) { @@ -295,7 +295,7 @@ void DistModel::InsertCommOp(std::string tmp_var_name, << ". The ring id is: " << ring_id << ". The group has: " << nranks << " ranks. Current rank in the group is: " << rank << ". The endpoint is: " << endpoint << ". Peer endpoints are: "; - for (auto ep : peer_endpoints) { + for (const auto &ep : peer_endpoints) { ss << ep << ", "; } VLOG(3) << ss.str(); diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 8daf0636ce890a..99dd6175787e86 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -82,7 +82,7 @@ void PreventVarsDelete( for (const auto& pair : *unused_vars) { const framework::OperatorBase* op = pair.first; std::vector cur_unused = pair.second; - for (auto name : vars_not_gc) { + for (auto const& name : vars_not_gc) { auto iter = std::find(cur_unused.begin(), cur_unused.end(), name); if (iter != cur_unused.end()) { VLOG(3) << "Removing var: [" << name @@ -165,7 +165,7 @@ void FleetExecutor::Init( while_block_vars = GetUnusedVarsAfterWhile( program_desc, task_node, inference_root_scope_vars); VLOG(3) << "Vars will be gced after while op"; - for (auto var : while_block_vars) { + for (auto const& var : while_block_vars) { VLOG(3) << var; } task_node->SetWhileBlockVars(while_block_vars); diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index 96210c16dd9ef5..ab155de79feedd 100755 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -75,3 +75,7 @@ cc_library( generated_op autograd_meta hook_utils) +# FIXME(Aurelius84): It seems utils library is depended in cycle, but +# CMake only find it twice to deal cycle depend problem. If it is still +# not found, ld error will be raised. +set_target_properties(utils PROPERTIES LINK_INTERFACE_MULTIPLICITY 3) diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index eb5cd7fb1242da..5fb5c99f1c09f1 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -410,7 +410,7 @@ static std::pair GetAttrType( ret = "std::vector"; if (is_arg) ret += "&"; val += "{"; - for (auto x : PADDLE_GET_CONST(std::vector, attr)) { + for (auto const& x : PADDLE_GET_CONST(std::vector, attr)) { val += "\"" + x + "\"" + ","; } if (val.size() > 1) val.pop_back(); @@ -1238,7 +1238,7 @@ static std::string GenerateGradNodeCreationContent( bool found_target_name = false; for (const auto& iter : op_base_infos) { const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); - for (auto iter : grad_outs_slot_map) { + for (auto const& iter : grad_outs_slot_map) { if ((!found_target_name) && (input_name == iter.second)) { const char* SET_GRAD_OUT_META_TEMPLATE = " grad_node->SetGradOutMeta(%s, %d);\n"; @@ -1256,7 +1256,7 @@ static std::string GenerateGradNodeCreationContent( bool found_target_name = false; for (const auto& iter : op_base_infos) { const auto& grad_outs_slot_map = iter.GetGradOutsSlotnameMap(); - for (auto iter : grad_outs_slot_map) { + for (auto const& iter : grad_outs_slot_map) { if ((!found_target_name) && (input_name == iter.second)) { const char* SET_GRAD_OUT_META_TEMPLATE = " grad_node->SetGradOutMeta(%s, %d);\n"; @@ -2142,7 +2142,7 @@ static std::string GenerateSingleOpBase( // [Generation] Get Full Zero std::string fill_zero_str = ""; if (ops_to_fill_zero_for_empty_grads.count(fwd_op_type)) { - for (auto iter : grad_ins) { + for (auto const& iter : grad_ins) { const std::string& grad_input_name = iter.first; if (grad_ins_grad_slotname_map.count(grad_input_name)) { size_t fwd_output_position = fwd_outputs_name_pos_map.at( @@ -2189,7 +2189,7 @@ static std::string GenerateSingleOpBase( "backward_inplace_tensor" + std::to_string(*outs_size); bool process_backward_inplace = false; std::string ins_contents_str = ""; - for (auto iter : grad_ins) { + for (auto const& iter : grad_ins) { const std::string& grad_input_name = iter.first; if (grad_ins_fwd_slotname_map.count(grad_input_name)) { @@ -2293,7 +2293,7 @@ static std::string GenerateSingleOpBase( paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str); generated_grad_function_body += ins_map_str; - for (auto iter : grad_ins) { + for (auto const& iter : grad_ins) { const std::string& grad_input_name = iter.first; if (grad_ins_fwd_slotname_map.count(grad_input_name)) { @@ -2335,7 +2335,7 @@ static std::string GenerateSingleOpBase( VLOG(6) << "Generated Ins Map"; // [Generation] Get Outs Map std::string outs_contents_str = ""; - for (auto iter : grad_outs) { + for (auto const& iter : grad_outs) { const std::string& grad_output_name = iter.first; if (grad_outs_slotname_map.count(grad_output_name)) { @@ -2440,7 +2440,7 @@ static std::string GenerateSingleOpBase( generated_grad_function_body += outs_map_str; generated_grad_function_body += outs_contents_str; generated_grad_function_body += "\n"; - for (auto iter : grad_outs) { + for (auto const& iter : grad_outs) { const std::string& grad_output_name = iter.first; if (grad_outs_slotname_map.count(grad_output_name)) { @@ -2498,7 +2498,7 @@ static std::string GenerateSingleOpBase( "%s[\"%s\"][0]);\n" " };\n"; std::string backward_inplace_map_str = ""; - for (auto iter : backward_inplace_map) { + for (auto const& iter : backward_inplace_map) { std::string backward_inplace_input_name = iter.first; std::string backward_inplace_output_name = iter.second; backward_inplace_map_str += paddle::string::Sprintf( @@ -2553,7 +2553,7 @@ static std::string GenerateSingleOpBase( // [Generation] Get Return std::string outputs_str = ""; size_t num_appended_outputs = 0; - for (auto iter : grad_outs) { + for (auto const& iter : grad_outs) { const std::string& grad_out_name = iter.first; const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); @@ -2594,7 +2594,7 @@ static std::string GenerateSingleOpBase( /* Handle Special Case: "PullSparseOp", etc For returns, append "GradOut" to the very end of return list. */ - for (auto iter : grad_outs) { + for (auto const& iter : grad_outs) { const std::string& grad_out_name = iter.first; const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 685ff8d4b72975..60e02a29d72b40 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -85,7 +85,7 @@ void EnforceGradNodeHasInput(GradNodeBase* node) { void DuplicateCheck(const std::vector& inputs, bool is_input) { std::unordered_set visisted_ins; std::string msg = is_input ? "inputs" : "outputs"; - for (auto in : inputs) { + for (auto const& in : inputs) { AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(in); PADDLE_ENFORCE_EQ( visisted_ins.count(auto_grad_meta), @@ -378,9 +378,9 @@ std::vector RunBackward( auto add_next_node_func = [&node_in_degree_map, &queue](GradNodeBase* next_node) { if (dynamic_cast(next_node)) { - queue.push_front(std::move(next_node)); + queue.push_front(next_node); } else { - queue.push_back(std::move(next_node)); + queue.push_back(next_node); } }; if (node_in_degree_map[next_node] == 0) { diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index af914d3ae3c791..5643c0e69391f0 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -154,7 +154,7 @@ static void ConstructFwdAndBwdMap( << "'s No." << j << " attrs: " << attrs_names[j] << " related to No." << i << " grad_attrs: " << grad_attrs_names[i]; - in_out_map[op_type][1][4][j] = i; + in_out_map[op_type][1][4][j] = i; // NOLINT } } } @@ -190,12 +190,12 @@ RunCustomOpNode::operator()(paddle::small_vector, } } - for (auto it : fwd_outs) { + for (auto it : fwd_outs) { // NOLINT VLOG(7) << "Insert fwd_outs to grad_inputs: " << it.first; tmp_ins[it.first] = RunCustomOpNode::Recover(&(it.second)); } - for (auto it : fwd_ins) { + for (auto it : fwd_ins) { // NOLINT // NOTE(HongyuJia): returned tensor maybe un-defined tensor when inputs // optional VLOG(7) << "Insert fwd_ins to grad_inputs: " << it.first; @@ -406,12 +406,12 @@ RunCustomOpDoubleGradNode::operator()( } } - for (auto it : fwd_outs) { + for (auto it : fwd_outs) { // NOLINT VLOG(7) << "Insert fwd_outs to grad_inputs: " << it.first; tmp_ins[it.first] = RunCustomOpDoubleGradNode::Recover(&(it.second)); } - for (auto it : fwd_ins) { + for (auto it : fwd_ins) { // NOLINT VLOG(7) << "Insert fwd_ins to grad_inputs: " << it.first; tmp_ins[it.first] = RunCustomOpDoubleGradNode::Recover(&(it.second)); } diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 3ceeda65c8e611..2619e706cfa134 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -277,6 +277,15 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in, meta.SetTensorMeta(dense_tensor.meta()); meta.SetPlace(fwd_in.place()); // Set DistAttr + PADDLE_ENFORCE_EQ(dist_tensor->defined(), + true, + phi::errors::InvalidArgument( + "The forward input DistTensor is not defined.")); + PADDLE_ENFORCE_NE( + dist_tensor->dist_attr().empty(), + true, + phi::errors::InvalidArgument( + "The forward input DistTensor's dist attr is empty.")); meta.SetDistAttr(dist_tensor->dist_attr()); SetIsRunAutoParallel(true); } else { diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index 5051dd39d9819f..34469f875198b7 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -89,11 +89,11 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id, auto init_grad = paddle::experimental::full(t.shape(), 1, t.dtype(), t.place()); auto global_dense_t = - static_cast(init_grad.impl().get()); + std::static_pointer_cast(init_grad.impl()); auto dist_t = static_cast(t.impl().get()); init_grad.set_impl(std::make_shared( - *global_dense_t, dist_t->dist_attr())); + global_dense_t, dist_t->dist_attr())); buffer_[slot_id][rank] = init_grad; } else { PADDLE_THROW(paddle::platform::errors::Fatal( diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 0ac940ab496e29..83e4424a212514 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -32,6 +32,7 @@ #include "paddle/pir/core/value.h" PHI_DECLARE_bool(enable_new_ir_in_executor); +PHI_DECLARE_bool(print_ir); namespace details { using Tensor = paddle::Tensor; @@ -191,6 +192,12 @@ static auto GetNameFromValue(const ::pir::Block *block, .dyn_cast() .AsString(); value2name[op->operand(0).source()] = name; + } else if (!is_input && op->name() == "builtin.shadow_output") { + name = op->attributes() + .at("output_name") + .dyn_cast() + .AsString(); + value2name[op->operand(0).source()] = name; } else if (is_input && op->name() == "builtin.get_parameter") { name = op->attributes() .at("parameter_name") @@ -463,12 +470,13 @@ inline void NewIRRunProgramAPI( auto *backward_program = backward_global_block->GetParentOp()->GetParentProgram(); - if (VLOG_IS_ON(4)) { + if (FLAGS_print_ir) { std::ostringstream print_stream; + print_stream << "ForwardProgram is :\n"; forward_program->Print(print_stream); - print_stream << "\n"; + print_stream << "BackwardProgram is:\n"; backward_program->Print(print_stream); - VLOG(4) << print_stream.str(); + std::cout << "Program (fwd | bwd): \n" << print_stream.str() << std::endl; } VLOG(10) << is_test << program_id; diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index eb2dcca4d3b314..28ca8636720dcd 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -28,6 +28,38 @@ #include "paddle/fluid/framework/variable.h" namespace egr { + +void SetGradOutputDistAttrIter::visit_element(paddle::Tensor* element, + const GradSlotMeta& meta) { + if (element == nullptr) { + VLOG(4) << "The input element is nullptr when calling " + "SetGradOutputDistAttrIter."; + return; + } + // Here the element is empty or defined DistTensor + VLOG(4) << "The input element is set DistTensor impl when calling " + "SetGradOutputDistAttrIter."; + element->set_impl(std::make_shared( + phi::DDim(), meta.DistAttr())); +} + +void SetGradOutputDistAttrIter::visit(paddle::Tensor* element) { + if (!out_meta_[out_indexes_[cur_pos_]].empty()) { + visit_element(element, out_meta_[out_indexes_[cur_pos_]][0]); + } + cur_pos_++; +} + +void SetGradOutputDistAttrIter::visit( + const std::vector& elements) { + if (!out_meta_[out_indexes_[cur_pos_]].empty()) { + for (size_t i = 0; i < elements.size(); ++i) { + visit_element(elements.at(i), out_meta_[out_indexes_[cur_pos_]][i]); + } + } + cur_pos_++; +} + /** * Implementation of Eager Utils. **/ diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index c1fe208c4c72a8..8dd950be0cbe22 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -97,41 +97,9 @@ class SetGradOutputDistAttrIter : public IterHelper { : out_meta_(out_meta), out_indexes_{out_indexes} {} private: - void visit_element(paddle::Tensor* element, const GradSlotMeta& meta) { - if (element == nullptr) { - return; - } - if (meta.DistAttr().empty()) { - return; - } - if (element->defined()) { - if (element->is_dist_tensor()) { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupport set defined dist tensor now.")); - } else { - // Only deal with dist tensor here - return; - } - } else { - element->set_impl(std::make_shared( - phi::DDim(), meta.DistAttr())); - } - } - void visit(paddle::Tensor* element) override { - if (!out_meta_[out_indexes_[cur_pos_]].empty()) { - visit_element(element, out_meta_[out_indexes_[cur_pos_]][0]); - } - cur_pos_++; - } - - void visit(const std::vector& elements) override { - if (!out_meta_[out_indexes_[cur_pos_]].empty()) { - for (size_t i = 0; i < elements.size(); ++i) { - visit_element(elements.at(i), out_meta_[out_indexes_[cur_pos_]][i]); - } - } - cur_pos_++; - } + void visit_element(paddle::Tensor* element, const GradSlotMeta& meta); + void visit(paddle::Tensor* element) override; + void visit(const std::vector& elements) override; const paddle::small_vector, kSlotSmallVectorSize>& out_meta_; diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 8814935e3fceb5..81075e0c5fb5bd 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -110,7 +110,7 @@ static void RunKernelFunc( // tensor here. custom_vec_in.emplace_back(paddle::Tensor()); } - kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in)); + kernel_ctx.EmplaceBackInputs(custom_vec_in); } else { // inputs Tensor if (ctx.HasInput(in_name)) { // general Tensor inputs auto* x = ctx.Input(in_name); @@ -231,7 +231,7 @@ static void RunKernelFunc( custom_t.set_impl(std::make_shared(*out)); custom_vec_out.emplace_back(custom_t); } - kernel_ctx.EmplaceBackOutputs(std::move(custom_vec_out)); + kernel_ctx.EmplaceBackOutputs(custom_vec_out); } else { // handle inplace optional outputs = None case if (!ctx.HasOutput(out_name)) { @@ -318,7 +318,7 @@ static void RunKernelFunc( } } } catch (platform::EnforceNotMet& exception) { - throw std::move(exception); + throw exception; } catch (std::exception& ex) { PADDLE_THROW(platform::errors::External("%s", ex.what())); } catch (...) { diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 902cd0f39369a2..19c5196d2f933a 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -1808,7 +1808,7 @@ int PaddleBoxDataFeed::Next() { output_pv_channel_->Get(pv_instance); pv_vec.push_back(pv_instance); ++index; - consume_pv_channel_->Put(std::move(pv_instance)); + consume_pv_channel_->Put(pv_instance); } this->batch_size_ = index; VLOG(3) << "pv_batch_size_=" << this->batch_size_ @@ -2448,9 +2448,9 @@ bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line, } // parse_logkey std::string log_key = std::string(str + pos, len); - uint64_t search_id; - uint32_t cmatch; - uint32_t rank; + uint64_t search_id = 0; + uint32_t cmatch = 0; + uint32_t rank = 0; parser_log_key(log_key, &search_id, &cmatch, &rank); rec->ins_id_ = log_key; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 18ed0bb6e901aa..6c66188567717f 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -138,7 +138,7 @@ std::vector DatasetImpl::GetSlots() { } } std::cout << "dataset use slots: "; - for (auto s : use_slots_) { + for (auto const& s : use_slots_) { std::cout << s << " | "; } std::cout << " end " << std::endl; @@ -216,7 +216,7 @@ template std::vector DatasetImpl::GetReaders() { std::vector ret; ret.reserve(readers_.size()); - for (auto i : readers_) { + for (auto const& i : readers_) { ret.push_back(i.get()); } return ret; @@ -1533,7 +1533,7 @@ void MultiSlotDataset::MergeByInsId() { break; } local_uint64.insert(slot); - rec.uint64_feasigns_.push_back(std::move(feature)); + rec.uint64_feasigns_.push_back(feature); } if (has_conflict_slot) { break; @@ -1550,7 +1550,7 @@ void MultiSlotDataset::MergeByInsId() { break; } local_float.insert(slot); - rec.float_feasigns_.push_back(std::move(feature)); + rec.float_feasigns_.push_back(feature); } if (has_conflict_slot) { break; diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 0615df45b76793..eec3439cf04316 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -52,11 +52,11 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( const std::vector &local_exec_scopes, const std::vector &places, std::vector graphs) - : strategy_(std::move(strategy)), - local_scopes_(std::move(local_scopes)), + : strategy_(strategy), + local_scopes_(local_scopes), local_exec_scopes_(local_exec_scopes), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), - places_(std::move(places)), + places_(places), graphs_(std::move(graphs)) { VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index b71c476a2c95e2..27be4b77176350 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -113,7 +113,7 @@ void FetchOpHandle::WaitAndMergeCPUFetchVars() const { } } else { auto &val = PADDLE_GET(FetchUnmergedList, *data_); - val.at(offset_) = std::move(tensors_); + val.at(offset_) = tensors_; } } diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 79b43b1b501db6..0aae1ce6b60d73 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -45,7 +45,7 @@ void GatherOpHandle::RunImpl() { in_var_handles.size(), places_.size())); - VarHandle *out_var_handle; + VarHandle *out_var_handle = nullptr; { auto out_var_handles = DynamicCast(this->Outputs()); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/framework/details/multi_devices_helper.cc b/paddle/fluid/framework/details/multi_devices_helper.cc index 4849ca34e3e956..d2379c2c49a19d 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.cc +++ b/paddle/fluid/framework/details/multi_devices_helper.cc @@ -176,7 +176,7 @@ static bool IsDataParallelInferenceGraphImpl( } bool IsDataParallelInferenceGraph(const ir::Graph &graph) { - size_t place_num; + size_t place_num = 0; std::unordered_map op_to_dev_idx; return IsDataParallelInferenceGraphImpl(graph, &op_to_dev_idx, &place_num); } @@ -196,7 +196,7 @@ bool IsDataParallelInferenceGraph(const ir::Graph &graph) { */ std::vector> TrySeparateToMultipleSingleDeviceGraphs( ir::Graph *graph) { - size_t place_num; + size_t place_num = 0; std::unordered_map op_to_dev_idx; if (!IsDataParallelInferenceGraphImpl(*graph, &op_to_dev_idx, &place_num)) { return {}; diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index 42f97e975ed3c2..b917c161193fbe 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -105,8 +105,8 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( const std::vector &local_exec_scopes, const std::vector &places, std::vector> graphs) - : strategy_(std::move(strategy)), - local_scopes_(std::move(local_scopes)), + : strategy_(strategy), + local_scopes_(local_scopes), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), places_(places), graphs_(std::move(graphs)), @@ -297,6 +297,7 @@ FetchResultType ParallelSSAGraphExecutor::Run( for (size_t i = 0; i < lodtensorarray_ptrs[0]->size(); ++i) { phi::DenseTensor var; std::vector ptrs; + ptrs.reserve(lodtensor_ptrs.size()); for (auto &lodtensorarray_ptr : lodtensorarray_ptrs) { ptrs.push_back(&(lodtensorarray_ptr->at(i))); } diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 7acf425fd77f30..fe43126ca8abe4 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -63,7 +63,7 @@ void ReduceOpHandle::RunImpl() { in_var_handles.size(), places_.size())); - VarHandle *out_var_handle; + VarHandle *out_var_handle = nullptr; { auto out_var_handles = DynamicCast(outputs_); diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 7ee7fa82250a99..9d275b0fd4c2e1 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -36,7 +36,7 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( std::vector var_infos, std::vector places, std::unique_ptr &&underlying_executor) - : strategy_(std::move(strategy)), + : strategy_(strategy), underlying_executor_(std::move(underlying_executor)), local_scopes_(std::move(local_scopes)), local_exec_scopes_(std::move(local_exec_scopes)), diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ce3fe004c40bb8..0397f87f6649ef 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -128,7 +128,7 @@ inline FetchResultType ThreadedSSAGraphExecutor::RunImpl( run_all_ops(ready_ops); // 2. Find ready variable - bool timeout; + bool timeout = false; auto cur_ready_vars = ready_vars->PopAll(1, &timeout); if (timeout) { for (auto &run_op_future : run_op_futures_) { diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index e69a25bb32781a..c9bd59f912d7a3 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -135,7 +135,7 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) { static_cast(table_idx))); TableParameter table; - for (auto i : param_.sparse_table()) { + for (auto const& i : param_.sparse_table()) { if (i.table_id() == table_id) { table = i; break; @@ -191,7 +191,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) { static_cast(table_idx))); TableParameter table; - for (auto i : param_.sparse_table()) { + for (auto const& i : param_.sparse_table()) { if (i.table_id() == table_id) { table = i; break; @@ -485,7 +485,7 @@ void DownpourWorker::TrainFilesWithProfiler() { double push_sparse_time = 0.0; double push_dense_time = 0.0; double copy_table_time = 0.0; - int cur_batch; + int cur_batch = 0; int batch_cnt = 0; uint64_t total_inst = 0; timeline.Start(); @@ -513,7 +513,7 @@ void DownpourWorker::TrainFilesWithProfiler() { uint64_t tid = static_cast( param_.program_config(0).pull_sparse_table_id(i)); TableParameter table; - for (auto j : param_.sparse_table()) { + for (auto const& j : param_.sparse_table()) { if (j.table_id() == tid) { table = j; break; @@ -599,7 +599,7 @@ void DownpourWorker::TrainFilesWithProfiler() { uint64_t tid = static_cast( param_.program_config(0).push_sparse_table_id(i)); TableParameter table; - for (auto i : param_.sparse_table()) { + for (auto const& i : param_.sparse_table()) { if (i.table_id() == tid) { table = i; break; @@ -804,7 +804,7 @@ void DownpourWorker::TrainFiles() { platform::SetNumThreads(1); device_reader_->Start(); int batch_cnt = 0; - int cur_batch; + int cur_batch = 0; while ((cur_batch = device_reader_->Next()) > 0) { if (copy_table_config_.need_copy()) { if (batch_cnt % copy_table_config_.batch_num() == 0) { @@ -819,7 +819,7 @@ void DownpourWorker::TrainFiles() { uint64_t tid = static_cast( param_.program_config(0).pull_sparse_table_id(i)); TableParameter table; - for (auto j : param_.sparse_table()) { + for (auto const& j : param_.sparse_table()) { if (j.table_id() == tid) { table = j; break; @@ -936,7 +936,7 @@ void DownpourWorker::TrainFiles() { uint64_t tid = static_cast( param_.program_config(0).push_sparse_table_id(i)); TableParameter table; - for (auto i : param_.sparse_table()) { + for (auto const& i : param_.sparse_table()) { if (i.table_id() == tid) { table = i; break; diff --git a/paddle/fluid/framework/downpour_worker_opt.cc b/paddle/fluid/framework/downpour_worker_opt.cc index 68c774965aeabf..d7d8a7ff883cdd 100644 --- a/paddle/fluid/framework/downpour_worker_opt.cc +++ b/paddle/fluid/framework/downpour_worker_opt.cc @@ -262,7 +262,7 @@ void DownpourWorkerOpt::CreateThreadOperatorsWithRerank( uint64_t tid = static_cast(param_.program_config(0).pull_sparse_table_id(i)); TableParameter table; - for (auto j : param_.sparse_table()) { + for (auto const& j : param_.sparse_table()) { if (j.table_id() == tid) { table = j; break; @@ -307,7 +307,7 @@ void DownpourWorkerOpt::TrainFiles() { platform::SetNumThreads(1); device_reader_->Start(); int batch_cnt = 0; - int cur_batch; + int cur_batch = 0; std::future pull_async_status; std::string async_wait_name = ""; for (int i = 0; i < param_.program_config(0).pull_sparse_table_id_size(); @@ -315,7 +315,7 @@ void DownpourWorkerOpt::TrainFiles() { uint64_t tid = static_cast(param_.program_config(0).pull_sparse_table_id(i)); TableParameter table; - for (auto j : param_.sparse_table()) { + for (auto const& j : param_.sparse_table()) { if (j.table_id() == tid) { table = j; break; @@ -344,7 +344,7 @@ void DownpourWorkerOpt::TrainFiles() { uint64_t tid = static_cast( param_.program_config(0).pull_sparse_table_id(i)); TableParameter table; - for (auto j : param_.sparse_table()) { + for (auto const& j : param_.sparse_table()) { if (j.table_id() == tid) { table = j; break; @@ -455,7 +455,7 @@ void DownpourWorkerOpt::TrainFiles() { uint64_t tid = static_cast( param_.program_config(0).push_sparse_table_id(i)); TableParameter table; - for (auto i : param_.sparse_table()) { + for (auto const& i : param_.sparse_table()) { if (i.table_id() == tid) { table = i; break; diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 5613a8dbf155e0..2e1eb0a58fe5a5 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -353,6 +353,10 @@ std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( return core; } +bool TensorSortHelper(const paddle::Tensor &t1, const paddle::Tensor &t2) { + return t1.name() < t2.name(); +} + std::unique_ptr<::pir::Program> ConstructFowardIrProgram( const paddle::framework::BlockDesc *forward_global_block, const paddle::framework::BlockDesc *backward_global_block, @@ -398,7 +402,9 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram( } std::set input_param_names; - for (auto ¶m : params) { + auto sorted_params = params; + std::sort(sorted_params.begin(), sorted_params.end(), TensorSortHelper); + for (auto ¶m : sorted_params) { auto &name = param.name(); auto p = param.place().GetType(); @@ -515,6 +521,8 @@ std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( for (auto &t : x_grad) { param_grad_names.push_back(t->name()); } + + std::sort(param_grad_names.begin(), param_grad_names.end()); for (auto &name : param_grad_names) { if (name == "@EMPTY@") { continue; diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 08d681ae6411fe..3d32781216402c 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -126,7 +126,7 @@ void HogwildWorker::SetZero(phi::DenseTensor *tensor, void HogwildWorker::BindingDataFeedMemory() { const std::vector &input_feed = device_reader_->GetUseSlotAlias(); - for (auto name : input_feed) { + for (auto const &name : input_feed) { device_reader_->AddFeedVar(thread_scope_->FindVar(name), name); } } @@ -239,7 +239,7 @@ void HogwildWorker::TrainFilesWithProfiler() { platform::Timer timeline; double total_time = 0.0; double read_time = 0.0; - int cur_batch; + int cur_batch = 0; int batch_cnt = 0; if (thread_id_ == 0) { quit_flag_.store(false); @@ -372,7 +372,7 @@ void HogwildWorker::TrainFiles() { int total_batch_num = 0; // how to accumulate fetched values here device_reader_->Start(); - int cur_batch; + int cur_batch = 0; int batch_cnt = 0; if (thread_id_ == 0) { quit_flag_.store(false); @@ -471,7 +471,7 @@ void HogwildWorker::PrintFetchVars() { } if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) { - time_t curtime; + time_t curtime = 0; time(&curtime); std::array mbstr; std::strftime(mbstr.data(), diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 38cc88f7ec9360..4c41bc27f1730e 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -505,8 +505,7 @@ CompatInferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { result.emplace_back(in.initialized() ? &in : nullptr); } - return paddle::optional>( - std::move(result)); + return paddle::optional>(result); } return paddle::none; } @@ -637,11 +636,11 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, if (ctx->IsRuntime()) { Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]); infer_meta_context.EmplaceBackAttr( - std::move(framework::MakePhiScalarFromVar(*var))); + framework::MakePhiScalarFromVar(*var)); } else { phi::Scalar tensor_scalar(-1); tensor_scalar.SetFromTensor(true); - infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar)); + infer_meta_context.EmplaceBackAttr(tensor_scalar); } } else { PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/init_default_kernel_signature_map.h b/paddle/fluid/framework/init_default_kernel_signature_map.h new file mode 100644 index 00000000000000..a6b6400dd19f59 --- /dev/null +++ b/paddle/fluid/framework/init_default_kernel_signature_map.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/utils/test_macros.h" + +// The implementation of InitDefaultKernelSignatureMap is in phi_utils.cc +namespace paddle { +namespace framework { +TEST_API void InitDefaultKernelSignatureMap(); +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/io/crypto/cipher.cc b/paddle/fluid/framework/io/crypto/cipher.cc index 2001e8a416a1a1..03e0cb4d0eb273 100644 --- a/paddle/fluid/framework/io/crypto/cipher.cc +++ b/paddle/fluid/framework/io/crypto/cipher.cc @@ -24,8 +24,8 @@ namespace framework { std::shared_ptr CipherFactory::CreateCipher( const std::string& config_file) { std::string cipher_name; - int iv_size; - int tag_size; + int iv_size = 0; + int tag_size = 0; std::unordered_map config; if (!config_file.empty()) { config = CipherUtils::LoadConfig(config_file); diff --git a/paddle/fluid/framework/io/crypto/cipher_utils.cc b/paddle/fluid/framework/io/crypto/cipher_utils.cc index c10da1ce6706cf..42d6223b729af5 100644 --- a/paddle/fluid/framework/io/crypto/cipher_utils.cc +++ b/paddle/fluid/framework/io/crypto/cipher_utils.cc @@ -72,7 +72,7 @@ std::unordered_map CipherUtils::LoadConfig( "make sure input filename is available.", config_file)); std::unordered_map ret; - char c; + char c = 0; std::string line; std::istringstream iss; while (std::getline(fin, line)) { diff --git a/paddle/fluid/framework/io/crypto/cipher_utils_test.cc b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc index 356c919cbcbe8c..ee4453bcaab676 100644 --- a/paddle/fluid/framework/io/crypto/cipher_utils_test.cc +++ b/paddle/fluid/framework/io/crypto/cipher_utils_test.cc @@ -46,19 +46,19 @@ TEST(CipherUtils, load_config) { EXPECT_TRUE(CipherUtils::GetValue(config, "key_str", &out_str)); EXPECT_EQ(out_str, std::string("ciphername")); - int out_int; + int out_int = 0; EXPECT_TRUE(CipherUtils::GetValue(config, "key_int", &out_int)); EXPECT_EQ(out_int, 1); - bool out_bool; + bool out_bool = false; EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool", &out_bool)); EXPECT_EQ(out_bool, true); - bool out_bool1; + bool out_bool1 = false; EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool1", &out_bool1)); EXPECT_EQ(out_bool1, false); - bool out_bool2; + bool out_bool2 = false; EXPECT_TRUE(CipherUtils::GetValue(config, "key_bool2", &out_bool2)); EXPECT_EQ(out_bool2, false); } diff --git a/paddle/fluid/framework/ir/add_support_int8_pass.cc b/paddle/fluid/framework/ir/add_support_int8_pass.cc index 21b45d1b1fa388..5dedfe59f6900a 100644 --- a/paddle/fluid/framework/ir/add_support_int8_pass.cc +++ b/paddle/fluid/framework/ir/add_support_int8_pass.cc @@ -61,8 +61,8 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const { // scale for one output for (auto out_node : quant_op->outputs) { for (auto out_op_node : out_node->outputs) { - for (auto name : out_op_node->Op()->InputNames()) { - for (auto input_name : out_op_node->Op()->Input(name)) { + for (auto const& name : out_op_node->Op()->InputNames()) { + for (auto const& input_name : out_op_node->Op()->Input(name)) { if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) { for (size_t i = 0; i < quanted_op_desc->OutputNames().size(); i++) { diff --git a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc index b4307f5ce758d4..44cb004fec1729 100644 --- a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc +++ b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc @@ -189,7 +189,7 @@ class CoalesceGradTensorPass : public ir::Pass { const { if (params_grads.empty()) return true; auto dtype = GetDtypeOfVar(vars_info, params_grads.front().second); - for (auto p_g : params_grads) { + for (auto const &p_g : params_grads) { auto next_dtype = GetDtypeOfVar(vars_info, p_g.second); if (next_dtype != dtype) { return false; diff --git a/paddle/fluid/framework/ir/constant_folding_pass.cc b/paddle/fluid/framework/ir/constant_folding_pass.cc index b32b0bb04b94cf..1bcec1e5a898c1 100644 --- a/paddle/fluid/framework/ir/constant_folding_pass.cc +++ b/paddle/fluid/framework/ir/constant_folding_pass.cc @@ -93,7 +93,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const { map[out_node->Name()] = 0; } // Forbid other node in graph having the same name with nodes in map - for (auto iter : map) { + for (auto const &iter : map) { for (auto node : graph->Nodes()) { if (node->IsVar() && node->Name() == iter.first) { map[node->Name()]++; diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 6b21bfa5defc9d..aa15b2696d7a12 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -371,8 +371,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { bool mkldnn_with_bias = is_mkldnn && has_bias; // Create eltwise_y (conv bias) variable - phi::DenseTensor* eltwise_y_in_tensor; - Node* eltwise_y_in_node; + phi::DenseTensor* eltwise_y_in_tensor = nullptr; + Node* eltwise_y_in_node = nullptr; if (!mkldnn_with_bias) { VarDesc eltwise_y_in_desc( patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in")); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc index 7cd069eea91a81..2f420bc858e37f 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc @@ -59,7 +59,7 @@ void TestMain(const std::string& conv_type) { auto* in = layers.data("in", {1, 3, 20, 20}); auto* filters = layers.data("filters", {3, 3, 2, 2}, true); auto* bias_0 = layers.data("bias_0", {3}, true); - VarDesc* conv_out; + VarDesc* conv_out = nullptr; if (conv_type == "conv_transpose") { conv_out = layers.conv2d_transpose(in, filters, bias_0); } else { diff --git a/paddle/fluid/framework/ir/delete_cast_op_pass.cc b/paddle/fluid/framework/ir/delete_cast_op_pass.cc index 6d4224982f79b4..59fd42241e0d4b 100644 --- a/paddle/fluid/framework/ir/delete_cast_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_cast_op_pass.cc @@ -84,7 +84,7 @@ static std::vector FindOpNodeWithInputName( if (!node->IsOp()) continue; auto inputs = node->Op()->Inputs(); bool find_input = false; - for (auto input : inputs) { + for (auto const& input : inputs) { auto input_names = input.second; if (std::count(input_names.begin(), input_names.end(), input_name) > 0) { find_input = true; @@ -103,7 +103,7 @@ static std::vector FindOpNodeWithOutputName( if (!node->IsOp()) continue; auto outputs = node->Op()->Outputs(); bool find_output = false; - for (auto output : outputs) { + for (auto const& output : outputs) { auto output_names = output.second; if (std::count(output_names.begin(), output_names.end(), output_name) > 0) { diff --git a/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc index 11d1339f35d249..17f0c642a60d18 100644 --- a/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc +++ b/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc @@ -53,6 +53,7 @@ VarDesc* AddWriteToArray(BlockDesc* block, OpDesc* op = block->AppendOp(); op->SetType("write_to_array"); std::vector x_names; + x_names.reserve(x.size()); for (auto k : x) { x_names.push_back(k->Name()); } diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index cb6a6e1d5d9dc4..286f7f08cdfc97 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -122,7 +122,7 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { platform::errors::InvalidArgument( "Input scale tensor's place should be CPU.")); - float input_scale; + float input_scale = NAN; if (input_scale_tensor.dtype() == phi::DataType::FLOAT32) { const float* input_scale_data = input_scale_tensor.data(); input_scale = input_scale_data[0]; diff --git a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc index ea0de9759f3a01..1917fb56f13ae9 100644 --- a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc +++ b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc @@ -36,7 +36,7 @@ namespace ir { bool HasOutVarName(Node* op_node, std::string name) { auto* op_desc = op_node->Op(); auto outputs = op_desc->Outputs(); - for (auto iter : outputs) { + for (auto const& iter : outputs) { auto out_names = iter.second; if (std::count(out_names.begin(), out_names.end(), name) > 0) { return true; @@ -155,7 +155,7 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps( } } - for (auto iter : ops_map) { + for (auto const& iter : ops_map) { auto ops = iter.second; auto* first_op_out = ops[0]->outputs[0]; auto first_op_out_name = first_op_out->Name(); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index 0b09d1b30f40af..cf5c9a2c94cf9b 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -45,7 +45,8 @@ void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { if (n->IsOp()) { auto* op = n->Op(); if (op->Type() == "dequantize_linear") { - Node *weight_var_node, *calcu_op_node, *while_op_node; + Node *weight_var_node = nullptr, *calcu_op_node = nullptr, + *while_op_node = nullptr; Node *dequantized_weight_var_node = nullptr, *scale_var_node = nullptr; // 1. Judge whether for dequant weight and find // weight_var_node/scale_var_node diff --git a/paddle/fluid/framework/ir/fuse_adamw_op_pass.cc b/paddle/fluid/framework/ir/fuse_adamw_op_pass.cc index 071e85b3a3303c..7e6eda7ac139e1 100644 --- a/paddle/fluid/framework/ir/fuse_adamw_op_pass.cc +++ b/paddle/fluid/framework/ir/fuse_adamw_op_pass.cc @@ -24,6 +24,7 @@ namespace ir { std::vector GetNodeNames(const std::vector &node_vector) { std::vector out_vector; + out_vector.reserve(node_vector.size()); for (auto i : node_vector) { out_vector.emplace_back(i->Name()); } diff --git a/paddle/fluid/framework/ir/fuse_bn_act_pass.cc b/paddle/fluid/framework/ir/fuse_bn_act_pass.cc index 2b3f64927f212b..048b33a649f94d 100644 --- a/paddle/fluid/framework/ir/fuse_bn_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_bn_act_pass.cc @@ -326,8 +326,8 @@ void FuseBatchNormActPass::ReLinkNodes(Graph *graph, IR_OP_VAR_LINK(fused_op, out); } - nodes2delete.insert(std::move(op_1)); - nodes2delete.insert(std::move(op_2)); + nodes2delete.insert(op_1); + nodes2delete.insert(op_2); GraphSafeRemoveNodes(graph, nodes2delete); } diff --git a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc index ceae7bb4beb9ba..2a24c5476a5010 100644 --- a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc @@ -322,9 +322,9 @@ void FuseBatchNormAddActPass::ReLinkNodes(Graph *graph, IR_OP_VAR_LINK(fused_op, out); } - nodes2delete.insert(std::move(op_1)); - nodes2delete.insert(std::move(op_2)); - nodes2delete.insert(std::move(op_3)); + nodes2delete.insert(op_1); + nodes2delete.insert(op_2); + nodes2delete.insert(op_3); GraphSafeRemoveNodes(graph, nodes2delete); } diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 60267177b7a2b3..3c550ca84042d2 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -352,7 +352,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { if (out->Name() == intermediate_out_args[0]) { if (out->outputs.empty()) { cur_node->outputs = this->RemoveNode(out, cur_node->outputs); - need_removed_nodes.insert(std::move(out)); + need_removed_nodes.insert(out); cur_node->Op()->SetAttr("save_intermediate_out", false); } } @@ -373,7 +373,7 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { out->outputs.empty()) { cur_node->Op()->SetOutput(GradVarName("IntermediateOut"), {}); cur_node->outputs = this->RemoveNode(out, cur_node->outputs); - need_removed_nodes.insert(std::move(out)); + need_removed_nodes.insert(out); } } } @@ -439,8 +439,8 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, IR_OP_VAR_LINK(fused_op, out); } - nodes2delete.insert(std::move(op_1)); - nodes2delete.insert(std::move(op_2)); + nodes2delete.insert(op_1); + nodes2delete.insert(op_2); GraphSafeRemoveNodes(graph, nodes2delete); } @@ -485,8 +485,8 @@ void FuseElewiseAddActPass::ReLinkNodes2(Graph *graph, IR_OP_VAR_LINK(fused_op, out); } - nodes2delete.insert(std::move(op_1)); - nodes2delete.insert(std::move(op_2)); + nodes2delete.insert(op_1); + nodes2delete.insert(op_2); GraphSafeRemoveNodes(graph, nodes2delete); } diff --git a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc index 2f92a58ba3a77e..0ba4ef378a5cbe 100644 --- a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc +++ b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc @@ -83,7 +83,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, auto matmul_op_desc = matmul_op->Op(); if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape)) return; - bool trans_x, trans_y; + bool trans_x = false, trans_y = false; GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y); OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); @@ -168,7 +168,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( auto activation = act_op->Op()->Type(); - bool trans_x, trans_y; + bool trans_x = false, trans_y = false; GetTransposeAttrsFromOp(*matmul_op_desc, &trans_x, &trans_y); OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); @@ -291,7 +291,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, auto matmul_grad_op_desc = matmul_grad_op->Op(); if (!IsGemmFromLinear_(matmul_grad_x_shape, matmul_grad_w_shape)) return; - bool trans_x, trans_y; + bool trans_x = false, trans_y = false; GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y); OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); @@ -319,10 +319,12 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, auto ele_add_grad_op_role_val = details::GetOpRoleVarsOrEmpty(*(ele_add_grad_op->Op())); std::vector fused_gemm_epilogue_grad_op_role_var; - for (auto i : matmul_grad_op_role_val) { + fused_gemm_epilogue_grad_op_role_var.reserve( + matmul_grad_op_role_val.size()); + for (auto const &i : matmul_grad_op_role_val) { fused_gemm_epilogue_grad_op_role_var.push_back(i); } - for (auto i : ele_add_grad_op_role_val) { + for (auto const &i : ele_add_grad_op_role_val) { fused_gemm_epilogue_grad_op_role_var.push_back(i); } fused_gemm_epilogue_grad_op_desc.SetAttr( @@ -430,7 +432,7 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( auto activation_grad = act_grad_op->Op()->Type(); - bool trans_x, trans_y; + bool trans_x = false, trans_y = false; GetTransposeAttrsFromOp(*matmul_grad_op_desc, &trans_x, &trans_y); OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); @@ -455,10 +457,12 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( auto ele_add_grad_op_role_val = details::GetOpRoleVarsOrEmpty(*(ele_add_grad_op->Op())); std::vector fused_gemm_epilogue_grad_op_role_var; - for (auto i : matmul_grad_op_role_val) { + fused_gemm_epilogue_grad_op_role_var.reserve( + matmul_grad_op_role_val.size()); + for (auto const &i : matmul_grad_op_role_val) { fused_gemm_epilogue_grad_op_role_var.push_back(i); } - for (auto i : ele_add_grad_op_role_val) { + for (auto const &i : ele_add_grad_op_role_val) { fused_gemm_epilogue_grad_op_role_var.push_back(i); } fused_gemm_epilogue_grad_op_desc.SetAttr( diff --git a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc index d2996fe4c64b39..4a9e316f30b2b8 100644 --- a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc +++ b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc @@ -192,7 +192,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { // Pass pre-condition check: check dtype of fusing vars auto fusing_var_dtype = GetDtypeOfVar(vars_info, aux_var_map.at(kParam).front()); - for (auto vars : aux_var_map) { + for (auto const &vars : aux_var_map) { for (auto &var_name : vars.second) { if (fusing_var_dtype != GetDtypeOfVar(vars_info, var_name)) { // Note(chenweihang): Currently the fuse_optimizer_ops strategy @@ -204,7 +204,7 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const { // Pass pre-condition check: gradients generated op kernel auto fusing_grad_var_names = aux_var_map.at(kGrad); - for (auto grad_var_name : fusing_grad_var_names) { + for (auto const &grad_var_name : fusing_grad_var_names) { if (!GradGeneratedOpKernelCheck(vars_info, grad_var_name)) { // Note(chenweihang): Currently the fuse_optimizer_ops strategy is risky // when gradient generated operator with kernel just support CPU or @@ -336,7 +336,7 @@ bool FuseOptimizerOpPass::GradGeneratedOpKernelCheck( } } } - for (auto op_type : check_op_set) { + for (auto const &op_type : check_op_set) { if (!OpWithKernelSupportCPUAndGPU(op_type)) { return false; } @@ -365,6 +365,7 @@ void FuseOptimizerOpPass::GradientsFilter( } } std::vector sorted_ops; + sorted_ops.reserve(new_grad_idx.size()); for (size_t i : new_grad_idx) { sorted_ops.emplace_back(opt_nodes->at(i)); } diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index 61d22a4b6d056f..32515b0b0eb105 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -303,7 +303,7 @@ std::string CodeGenerator::EmitParameters( output_args.push_back(args_str); } } - for (auto args : output_args) { + for (auto const& args : output_args) { ret << args; if (index != output_args.size() - 1) { ret << ", "; diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc index edd4d455df8a6c..d87ef1c1e39eac 100644 --- a/paddle/fluid/framework/ir/fusion_group/operation.cc +++ b/paddle/fluid/framework/ir/fusion_group/operation.cc @@ -21,7 +21,7 @@ namespace framework { namespace ir { namespace fusion_group { -OperationMap* OperationMap::map = nullptr; +OperationMap *OperationMap::map = nullptr; OperationMap::OperationMap() { InsertUnaryElementwiseOperations(); @@ -31,7 +31,7 @@ OperationMap::OperationMap() { std::unordered_set OperationMap::Find(int type) { std::unordered_set res; - for (auto& t : operations_) { + for (auto &t : operations_) { if (t.second.type == type) { res.insert(t.first); } @@ -60,15 +60,15 @@ void OperationMap::Insert(int type, // grad_inputs = inputs + outputs + grad of outputs std::vector grad_input_names = input_names; - for (auto name : output_names) { + for (auto const &name : output_names) { grad_input_names.push_back(name); } - for (auto name : output_names) { + for (auto const &name : output_names) { grad_input_names.push_back(GradVarName(name)); } // grad_output = grad of inputs std::vector grad_output_names; - for (auto name : input_names) { + for (auto const &name : input_names) { grad_output_names.push_back(GradVarName(name)); } Operation grad_op(type, diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index e0ab584ee32256..821bed7e6d53df 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -286,7 +286,7 @@ GraphPatternDetector::handle_t GetGenerateDelete( for (const std::unique_ptr& pdnode : pattern.nodes()) { remove_nodes.emplace(subgraph.at(pdnode.get())); } - for (auto iter : var_node_maps) { + for (auto const& iter : var_node_maps) { remove_nodes.erase(iter.second); } GraphSafeRemoveNodes(graph, remove_nodes); @@ -424,7 +424,7 @@ GraphPatternDetector::handle_t GetGenerateRewrite( for (const std::unique_ptr& pdnode : pattern.nodes()) { remove_nodes.emplace(subgraph.at(pdnode.get())); } - for (auto iter : var_node_maps) { + for (auto const& iter : var_node_maps) { for (auto& node : iter.second) { remove_nodes.erase(node); } diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 5d7054721db53a..30a001777bd587 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -164,7 +164,7 @@ std::vector TopologySortOperations(const Graph &graph) { "Generated graph shouldn't contain cycle.")); std::unordered_set visited; std::vector ret; - for (auto adj : adj_list) { + for (auto const &adj : adj_list) { if (visited.find(adj.first) == visited.end()) { SortHelper(adj_list, adj.first, &visited, &ret); } @@ -449,7 +449,7 @@ std::vector TopologySortGraphByDescOrder(const Graph &graph) { "Generated graph shouldn't contain cycle.")); std::unordered_set visited; std::vector ret; - for (auto adj : adj_list) { + for (auto const &adj : adj_list) { if (visited.find(adj.first) == visited.end()) { SortHelper(adj_list, adj.first, &visited, &ret); } @@ -502,6 +502,7 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { // TODO(Ruibiao) : Set OpDeviceAttrName when needed std::vector output_names; + output_names.reserve(node.outputs.size()); for (auto out : node.outputs) { output_names.emplace_back(out->Name()); } @@ -741,7 +742,7 @@ template static void GetGraphVarDesc(const Graph &graph, const std::unordered_set &nodes, std::vector *vars) { - for (T node : nodes) { + for (T const &node : nodes) { if (node->IsVar() && node->Var() && node->GetVarNodeBlockId() == graph.GetBlockId()) { vars->emplace_back(*node->Var()->Proto()); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index 4f430ba4041d69..3f68f5c6dd72b4 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -149,7 +149,7 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const { } } } - decltype(op_attrs)* attr; + decltype(op_attrs)* attr = nullptr; if (marked_nodes.count(n)) { attr = &marked_var_attrs; } else if (const_cast(n)->Var() && diff --git a/paddle/fluid/framework/ir/inplace_op_var_pass.cc b/paddle/fluid/framework/ir/inplace_op_var_pass.cc index 5bbe980daaba7e..7648fd0c89a26c 100644 --- a/paddle/fluid/framework/ir/inplace_op_var_pass.cc +++ b/paddle/fluid/framework/ir/inplace_op_var_pass.cc @@ -85,12 +85,12 @@ std::vector InplaceOpVarPass::GetControlFlowVarNames( for (auto* node : graph->Nodes()) { if (!node->IsOp() || control_flow_ops_.count(node->Op()->Type()) == 0) continue; - for (auto in_names : node->Op()->Inputs()) { + for (auto const& in_names : node->Op()->Inputs()) { auto var_names = in_names.second; control_flow_var_names.insert( control_flow_var_names.end(), var_names.begin(), var_names.end()); } - for (auto out_names : node->Op()->Outputs()) { + for (auto const& out_names : node->Op()->Outputs()) { auto var_names = out_names.second; control_flow_var_names.insert( control_flow_var_names.end(), var_names.begin(), var_names.end()); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index 40525a14141a6a..0398117e08b8fb 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -86,7 +86,7 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars( lod_tensors->clear(); other_vars->clear(); - for (auto &op_vars_pair : m) { + for (auto const &op_vars_pair : m) { for (auto var_name : op_vars_pair.second) { auto *var_desc = TryGetLatestVarDesc( vars[op_vars_pair.first->GetScopeIdx()].at(var_name)); @@ -247,7 +247,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { op->GetScope(), op->GetScopeIdx(), op->GetPlace(), - std::move(var_info), + var_info, gcs.at(places[op->GetScopeIdx()]).get()); auto it = std::find_if( diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc index a4cc550938495d..9c60a665de0021 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc @@ -123,6 +123,7 @@ class ReferenceCountPassTestHelper { std::vector LastLivedOps(const std::string &name) const { auto &ops = last_live_ops_of_vars_[0].at(name).ops(); std::vector ret; + ret.reserve(ops.size()); for (auto *op : ops) { ret.emplace_back(op->GetOp()); } diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index eee24e01a328b3..1738259d60f004 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -412,7 +412,7 @@ std::unordered_set ComputePropagateScalesMkldnnPass::UpdateScales( auto out_iter = var_quant_scales->find(op_node->Op()->Output("Out")[0]); if (out_iter != var_quant_scales->end()) { std::vector input_names = op_node->Op()->Input("X"); - for (auto input_name : input_names) { + for (auto const& input_name : input_names) { auto concat_in_iter = var_quant_scales->find(input_name); if (concat_in_iter == var_quant_scales->end()) (*var_quant_scales)[input_name] = out_iter->second; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 9c9ea82445d60b..8d8504708f0373 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -349,14 +349,14 @@ bool CPUQuantizePass::AreScalesPresentForVarNames( bool present = true; if (var_quant_scales_->empty()) { auto& scales = Get("quant_var_scales"); - for (auto name : names) { + for (auto const& name : names) { if (scales.find(name) == scales.end()) { present = false; LogScaleIsMissingForVarName(name); } } } else { - for (auto name : names) { + for (auto const& name : names) { if (var_quant_scales_->find(name) == var_quant_scales_->end()) { present = false; LogScaleIsMissingForVarName(name); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index 052c26ba8e2681..c9461060e443f2 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -202,7 +202,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( if (dequant_scale == quant_scale && dequant_shift == quant_shift) { // squash dequantize-quantize to nothing auto quant_out_var_name = quant_out->Name(); - for (auto input_name : next_op_desc->InputNames()) { + for (auto const& input_name : next_op_desc->InputNames()) { auto& input_names = next_op_desc->MutableInputs()->at(input_name); std::replace(input_names.begin(), input_names.end(), diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc index d007ef16d33ec2..47c76289d187c4 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc @@ -27,7 +27,7 @@ using string::PrettyLogDetail; void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const { auto act_types = GetSupportedActivations(); - for (auto act_type : act_types) FuseFCAct(graph, act_type); + for (auto const &act_type : act_types) FuseFCAct(graph, act_type); } void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc index 655183dc712c02..0087886c1c8d7b 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc @@ -224,7 +224,7 @@ void QuantDequantMkldnnPass::CollectOutputScalesFromAttr( auto var_name_map = op_desc->Outputs(); for (auto& item : var_name_map) { - for (auto var_name : item.second) { + for (auto const& var_name : item.second) { var_quant_scales->insert(std::make_pair(var_name, scale_v)); } } diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc index d5d3804e75ca30..e35e5d297db9b9 100644 --- a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -130,7 +130,7 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { auto node = forward_backward_ops[node_idx]; OpDesc repeated_op(*(node->Op()), node->Op()->Block()); // 3. rename grad outputs to current repeat. - for (auto outname : repeated_op.OutputArgumentNames()) { + for (auto const& outname : repeated_op.OutputArgumentNames()) { if (grad_names.find(outname) != grad_names.end()) { std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i); repeated_op.RenameOutput(outname, new_gname); @@ -244,11 +244,12 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { // 5. create GRAD merge op node: sum(repeat.0...repeat.n) -> // scale(1/num_repeats) - for (auto kv : grad_repeated_map) { + for (auto const& kv : grad_repeated_map) { OpDesc sum_op; sum_op.SetType("sum"); std::vector repeated_grad_names; std::vector param_grad_op_role_var; + repeated_grad_names.reserve(kv.second.size()); for (auto r : kv.second) { repeated_grad_names.push_back(r->Var()->Name()); } diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc index 6b4e786a5aae9e..dc18979260f928 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc @@ -51,7 +51,7 @@ class FuseAllReduceOpPass : public ir::Pass { size_t num_of_all_reduce = params_grads.size(); std::unordered_set grads; grads.reserve(num_of_all_reduce); - for (auto p_g : params_grads) { + for (auto const &p_g : params_grads) { grads.insert(p_g.second); } diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc index 35f4e4830d882b..85f62c4a293fce 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc @@ -351,7 +351,7 @@ bool MultiDevSSAGraphBuilderBase::NeedCollectiveForGrad( // NOTE: This is for the case that all gradients should add collective ops for (auto *node : ops) { if (node->Op()->Type() != "allreduce") continue; - for (auto in_name : node->Op()->InputArgumentNames()) { + for (auto const &in_name : node->Op()->InputArgumentNames()) { if (in_name == grad_name) { return false; } @@ -862,7 +862,7 @@ int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { size_t BalanceVarSSAGraphBuilder::GetAppropriateDeviceID( const std::vector &var_names) const { int64_t numel_sum = 0; - for (auto var_name : var_names) { + for (auto const &var_name : var_names) { if (all_vars_.find(var_name) == all_vars_.end()) continue; auto var_desc = all_vars_.at(var_name); PADDLE_ENFORCE_NOT_NULL(var_desc, @@ -1137,6 +1137,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { details::BuildStrategy::ReduceStrategy::kAllReduce && node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; + input_var_names.reserve(node->inputs.size()); for (ir::Node *n : node->inputs) { input_var_names.push_back(n->Name()); } @@ -1162,6 +1163,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { } } else if (node->Op()->Type() == "recv") { std::vector output_var_names; + output_var_names.reserve(node->inputs.size()); for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); } @@ -1245,6 +1247,8 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, int op_dev_id = -1; std::vector input_var_names; std::vector output_var_names; + input_var_names.reserve(node->inputs.size()); + output_var_names.reserve(node->outputs.size()); for (ir::Node *input : node->inputs) { input_var_names.push_back(input->Name()); } diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index a950ec191a4bf7..0fd3a71754f6d9 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -273,9 +273,9 @@ PDNode* MultiHeadMatmulPattern::operator()() { auto* mul0_out_var = pattern->NewNode(mul0_out_repr())->assert_is_ops_output(mul_ops); - decltype(mul0) eltadd0; - decltype(mul0) eltadd0_b_var; - decltype(mul0) eltadd0_out_var; + decltype(mul0) eltadd0 = nullptr; + decltype(mul0) eltadd0_b_var = nullptr; + decltype(mul0) eltadd0_out_var = nullptr; mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); @@ -353,9 +353,9 @@ PDNode* MultiHeadMatmulPattern::operator()() { auto* mul1_out_var = pattern->NewNode(mul1_out_repr())->assert_is_ops_output(mul_ops); - decltype(mul1) eltadd1; - decltype(mul1) eltadd1_b_var; - decltype(mul1) eltadd1_out_var; + decltype(mul1) eltadd1 = nullptr; + decltype(mul1) eltadd1_b_var = nullptr; + decltype(mul1) eltadd1_out_var = nullptr; mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); @@ -389,9 +389,9 @@ PDNode* MultiHeadMatmulPattern::operator()() { auto* mul2_out_var = pattern->NewNode(mul2_out_repr())->assert_is_ops_output(mul_ops); - decltype(mul2) eltadd2; - decltype(mul2) eltadd2_b_var; - decltype(mul2) eltadd2_out_var; + decltype(mul2) eltadd2 = nullptr; + decltype(mul2) eltadd2_b_var = nullptr; + decltype(mul2) eltadd2_out_var = nullptr; mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); @@ -465,9 +465,9 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* mul0_out_var = pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul0) eltadd0; - decltype(mul0) eltadd0_b_var; - decltype(mul0) eltadd0_out_var; + decltype(mul0) eltadd0 = nullptr; + decltype(mul0) eltadd0_b_var = nullptr; + decltype(mul0) eltadd0_out_var = nullptr; mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); @@ -539,9 +539,9 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* mul1_out_var = pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul1) eltadd1; - decltype(mul1) eltadd1_b_var; - decltype(mul1) eltadd1_out_var; + decltype(mul1) eltadd1 = nullptr; + decltype(mul1) eltadd1_b_var = nullptr; + decltype(mul1) eltadd1_out_var = nullptr; mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); @@ -575,9 +575,9 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() { auto* mul2_out_var = pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul2) eltadd2; - decltype(mul2) eltadd2_b_var; - decltype(mul2) eltadd2_out_var; + decltype(mul2) eltadd2 = nullptr; + decltype(mul2) eltadd2_b_var = nullptr; + decltype(mul2) eltadd2_out_var = nullptr; mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); diff --git a/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc index c974f5fafd68b3..be5fad23fd6e2d 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_roformer_fuse_pass.cc @@ -53,9 +53,9 @@ PDNode* MultiHeadMatmulRoformerPattern::operator()() { auto* mul0_out_var = pattern->NewNode(mul0_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul0) eltadd0; - decltype(mul0) eltadd0_b_var; - decltype(mul0) eltadd0_out_var; + decltype(mul0) eltadd0 = nullptr; + decltype(mul0) eltadd0_b_var = nullptr; + decltype(mul0) eltadd0_out_var = nullptr; mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); @@ -165,9 +165,9 @@ PDNode* MultiHeadMatmulRoformerPattern::operator()() { auto* mul1_out_var = pattern->NewNode(mul1_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul1) eltadd1; - decltype(mul1) eltadd1_b_var; - decltype(mul1) eltadd1_out_var; + decltype(mul1) eltadd1 = nullptr; + decltype(mul1) eltadd1_b_var = nullptr; + decltype(mul1) eltadd1_out_var = nullptr; mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add"); @@ -232,9 +232,9 @@ PDNode* MultiHeadMatmulRoformerPattern::operator()() { auto* mul2_out_var = pattern->NewNode(mul2_out_repr())->assert_is_ops_output(matmul_ops); - decltype(mul2) eltadd2; - decltype(mul2) eltadd2_b_var; - decltype(mul2) eltadd2_out_var; + decltype(mul2) eltadd2 = nullptr; + decltype(mul2) eltadd2_b_var = nullptr; + decltype(mul2) eltadd2_out_var = nullptr; mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add"); diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc index a0693e8a394338..d4e8a1683ed18a 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc @@ -114,7 +114,7 @@ TEST(SeqPoolConcatFusePass, basic) { std::vector({"j"})); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op // Add 1 Node: fusion_seqpool_concat @@ -168,7 +168,7 @@ TEST(SeqPoolConcatFusePass, advanced) { std::vector({"h"})); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove 7 Nodes: op1, op2, c, d, e, f concat_op // Add 1 Node: fusion_seqpool_concat @@ -204,7 +204,7 @@ TEST(SeqPoolConcatFusePass, more_inputs) { for (int num : {1, 2, 10}) { ProgramDesc prog = BuildProgramDesc(num); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op // Add Node: fusion_seqpool_concat op diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc index 16296d83dae1c1..eeef9c73db3d71 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc @@ -145,7 +145,7 @@ void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const { std::vector subgraph_ins_name; std::unordered_set marked_nodes; - Node* cvm_input_of_cvm; + Node* cvm_input_of_cvm = nullptr; Node* concat_out_var = concat_node->outputs[0]; GraphPatternDetector::handle_t handler = diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc index f3adab84d3a3da..390a6fc0706dfc 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass_tester.cc @@ -151,7 +151,7 @@ TEST(SeqPoolCVMConcatFusePass, basic) { std::vector({"m"})); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove 16 Nodes: op1, op2, op3, op4, op5, op6, d, e, f, g, h, i, j, k, l, // concat_op @@ -219,7 +219,7 @@ TEST(SeqPoolCVMConcatFusePass, advanced) { std::vector({"j"})); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove 11 Nodes: op1, op2, op4, op5, c, d, e, f, h, i, concat_op // Add 1 Node: fusion_seqpool_cvm_concat @@ -265,7 +265,7 @@ TEST(SeqPoolCVMConcatFusePass, more_inputs) { for (int num : {1, 2, 10}) { ProgramDesc prog = BuildProgramDesc(num); std::unique_ptr graph(new ir::Graph(prog)); - int before, after; + int before = 0, after = 0; graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); // Remove Nodes: n * (seqpool_op, seqpool_out, out_unused, cvm_op, cvm_out), // and concat_op diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index 9b3eb12c3eef7b..b300dcd76119c9 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -68,7 +68,7 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, return nullptr; } for (auto* var : x->inputs) { - for (auto name : x->Op()->Input(arg_name)) { + for (auto const& name : x->Op()->Input(arg_name)) { if (var->Name() == name) { return var; } diff --git a/paddle/fluid/framework/ir/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc index a6576203235923..d15a117e1a38a0 100644 --- a/paddle/fluid/framework/ir/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -234,6 +234,7 @@ void FlexibleDFS(const std::vector &source, } FNode; std::vector stack; + stack.reserve(source.size()); for (auto &node : source) { stack.push_back(FNode{node, false}); } diff --git a/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc b/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc index daf4e8ca3204a0..da39950280320d 100644 --- a/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc +++ b/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc @@ -61,6 +61,7 @@ void TransferLayoutElimPass::PutTranferlayoutAfterOp( // group_norm has 3 inputs, but we do not need there is a transfer_layout // before Bias and Scale so we extract useful_var1s from op_node->inputs. std::vector useful_var1s; + useful_var1s.reserve(op_node->inputs.size()); for (auto var1 : op_node->inputs) { // if (var1->inputs.size() >= 1 && // var1->inputs[0]->Op()->Type() == "transfer_layout") { diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc index 64f2801bf0220e..6774a6baae0230 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -69,6 +69,7 @@ void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse( GraphPatternDetector gpd; std::vector input_nodes; + input_nodes.reserve(times); for (int i = 0; i < times; i++) { input_nodes.push_back(gpd.mutable_pattern() ->NewNode("x" + std::to_string(i)) @@ -166,6 +167,7 @@ void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse( int concat_axis = PADDLE_GET_CONST(int, concat_op->Op()->GetAttr("axis")); std::string output_name = concat_out->Name(); + input_names.reserve(times); for (int i = 0; i < times; i++) { input_names.push_back(nodes[i * kNumFields]->Name()); } diff --git a/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc index 6c6174b9267016..251cf1b02e9d80 100644 --- a/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc @@ -254,6 +254,7 @@ void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl( float* weight_scale_data = weight_scale_tensor->data(); auto weight_scale_nums = weight_scale_tensor->numel(); + weight_scale.reserve(weight_scale_nums); for (int i = 0; i < weight_scale_nums; i++) { weight_scale.push_back(weight_scale_data[i] / static_cast(range)); } diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 32675d5fa09c1b..96cff2521dfe7c 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -235,7 +235,7 @@ void SerializeToStream(std::ostream &os, void SerializeToStream(std::ostream &os, const phi::DenseTensor &tensor) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext *dev_ctx; + const platform::DeviceContext *dev_ctx = nullptr; auto place = tensor.place(); dev_ctx = pool.Get(place); SerializeToStream(os, tensor, *dev_ctx); @@ -243,7 +243,7 @@ void SerializeToStream(std::ostream &os, const phi::DenseTensor &tensor) { void DeserializeFromStream(std::istream &os, phi::DenseTensor *tensor) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext *dev_ctx; + const platform::DeviceContext *dev_ctx = nullptr; dev_ctx = pool.Get(platform::CPUPlace()); DeserializeFromStream(os, tensor, *dev_ctx); } @@ -255,7 +255,7 @@ void DeserializeFromStream(std::istream &is, const std::vector &shape) { { // the 1st field, unit32_t version for DenseTensor - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version), true, @@ -271,7 +271,7 @@ void DeserializeFromStream(std::istream &is, } { // the 2st field, LoD information - uint64_t lod_level; + uint64_t lod_level = 0; is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); auto &lod = *tensor->mutable_lod(); lod.resize(lod_level); @@ -286,7 +286,7 @@ void DeserializeFromStream(std::istream &is, const platform::DeviceContext &dev_ctx) { { // the 1st field, unit32_t version for DenseTensor - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(paddle::framework::IsTensorVersionSupported(version), true, @@ -302,12 +302,12 @@ void DeserializeFromStream(std::istream &is, } { // the 2st field, LoD information - uint64_t lod_level; + uint64_t lod_level = 0; is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); auto &lod = *tensor->mutable_lod(); lod.resize(lod_level); for (uint64_t i = 0; i < lod_level; ++i) { - uint64_t size; + uint64_t size = 0; is.read(reinterpret_cast(&size), sizeof(size)); std::vector tmp(size / sizeof(size_t)); is.read(reinterpret_cast(tmp.data()), diff --git a/paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.cc b/paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.cc index e7efc1f10c3243..e9df08d4698e28 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/garbage_collector/fast_garbage_collector.cc @@ -53,6 +53,23 @@ void InterpreterCoreFastGarbageCollector::Add(Variable* var) { for (auto& t : *tensor_arr) { Add(t.MoveMemoryHolder()); } + } else if (var->IsType()) { + Add(var->GetMutable() + ->mutable_indices() + ->MoveMemoryHolder()); + Add(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); + } else if (var->IsType()) { + Add(var->GetMutable() + ->mutable_cols() + ->MoveMemoryHolder()); + Add(var->GetMutable() + ->mutable_crows() + ->MoveMemoryHolder()); + Add(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); } else if (var->IsType>()) { // NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE // refer to executor.cc to see what old garbage collector does. diff --git a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt index 64c14374162c6c..8621c158a23e22 100644 --- a/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/instruction/CMakeLists.txt @@ -1,7 +1,8 @@ cc_library( instruction_base SRCS instruction_base.cc phi_kernel_instruction.cc - legacy_kernel_instruction.cc cond_instruction.cc instruction_util.cc + legacy_kernel_instruction.cc cond_instruction.cc while_instruction.cc + instruction_util.cc DEPS pir_adaptor phi framework_proto) if(WITH_CINN AND NOT CINN_ONLY) diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 8841103213400d..0c6442cd1f9d3f 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -17,102 +17,111 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/hlir/framework/instruction.h" +#include "paddle/cinn/hlir/framework/new_ir_compiler.h" +#include "paddle/cinn/runtime/cuda/cuda_util.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/framework/paddle2cinn/transform_type.h" namespace paddle { namespace framework { -// TODO(Aurelius84): Think deeply what's the responsibility is it. -// Currently it assumes CinnLaunchContext role. -class JitContext { +class CinnJitInstruction::FnPtrImpl { + using CUDAJITInfo = cinn::hlir::framework::newir::CUDAJITInfo; + public: - cinn_buffer_t* GetCinnBufferOfVar(const std::string& name) { - auto res = paddle2argument_.find(name); - PADDLE_ENFORCE_NE( - res, - paddle2argument_.end(), - platform::errors::NotFound( - "Variable(%s) not found in compilation result", name)); - return static_cast(res->second); - } + explicit FnPtrImpl(const CUDAJITInfo& cuda_jit_info) + : cuda_jit_info_(cuda_jit_info) {} + void Run(const std::vector& kernel_args, void* stream) { + func_args_.clear(); + ptr_storage_.resize(kernel_args.size()); + for (size_t i = 0; i < kernel_args.size(); ++i) { + ptr_storage_[i] = kernel_args[i]->data(); + func_args_.push_back(ptr_storage_.data() + i); + } - // NOTE(Aurelius84): Before running each instruction, we should share Tensor - // memory from paddle scope with cinn_buffer_t from cinn scope including - // inputs and outputs. - void ShareMemToCinn(const std::string& var_name, - const phi::Place& place, - Scope* scope) { - cinn_buffer_t* buffer = GetCinnBufferOfVar(var_name); - auto* tensor = scope->GetVar(var_name)->GetMutable(); - // TODO(Aurelius84): Maybe we should consider to unify the Scope - // structure between paddle and cinn, so that we don't need to develop - // the glue code. - buffer->memory = reinterpret_cast(tensor->mutable_data( - place, paddle2cinn::TransToPaddleDataType(buffer->type))); + CUDA_DRIVER_CALL( + cuLaunchKernel(static_cast(cuda_jit_info_.fn_ptr), + cuda_jit_info_.grid_dims[0], + cuda_jit_info_.grid_dims[1], + cuda_jit_info_.grid_dims[2], + cuda_jit_info_.block_dims[0], + cuda_jit_info_.block_dims[1], + cuda_jit_info_.block_dims[2], + 0, // share memory + static_cast(stream), + func_args_.data(), + nullptr)) } - // TODO(Aurelius84): Add logic to parse stream for different device. - void* GetStream() { return nullptr; } - private: - // because a cinn_pod_value_t does not own a cinn_buffer_t object, - // an extra stroage is necessary to keep those objects and they can - // not be released until the runtime program finish execution. - std::vector> hold_buffers_; - // this map saves all execution arguments with their cinn names as key, - // and it is passed to the Execute interface of a cinn runtime program. - std::map name2argument_; - // this map saves all execution arguments with paddle variables as key, - // this map conbine name2argument_ and paddle2cinn_varmap_ - std::map paddle2argument_; -}; + CUDAJITInfo cuda_jit_info_; -// TODO(Aurelius84): Impl should hold JitContext instance to -// deliver the device context for 'instr->Run' and responsible -// to deal with inner buffer_t shareing between framework::Scope -// and cinn::Scope. -class CinnJitInstruction::Impl { - using Instruction = cinn::hlir::framework::Instruction; - - public: - explicit Impl(Instruction* instr) : instr_(instr) {} - // TODO(Aurelus84): Support to specify name2podargs and stream arguments. - void Run() { - PADDLE_ENFORCE_NOT_NULL( - instr_, platform::errors::NotFound("instr_ should not be NULL")); - instr_->Run(/*name2podargs=*/nullptr, - false, - /*stream=*/nullptr, - /*use_cache=*/true); - } - const Instruction* pointer() const { return instr_; } - - private: - Instruction* instr_{nullptr}; + std::vector ptr_storage_; + std::vector func_args_; }; -CinnJitInstruction::CinnJitInstruction(size_t id, - const platform::Place& place, - ::pir::Operation* op, - Scope* scope) +CinnJitInstruction::CinnJitInstruction( + size_t id, + const platform::Place& place, + ::pir::Operation* op, + const ValueExecutionInfo& value_exec_info) : InstructionBase(id, place) { - // TODO(Aurelius84): We shall simplify members of JitKernelOp to make it - // only hold related function ptrs. Impl is the real runtime data structure - // responsible to construct hlir::framework::Instruction. auto jit_kernel_op = op->dyn_cast(); - impl_ = std::make_shared(jit_kernel_op.instruction()); + fn_ptr_impl_ = std::make_shared(jit_kernel_op.cuda_jit_info()); op_ = op; + + place_ = place; + + InitInputsOutputsIds(op, value_exec_info); + + for (size_t i = 0; i < op->num_operands(); ++i) { + auto in = op->operand_source(i); + + auto var_name = value_exec_info.GetVarName(in); + + auto tensor = value_exec_info.GetScope() + ->Var(var_name) + ->GetMutable(); + + tensor_args_.push_back(tensor); + } + + dev_ctx_ = phi::DeviceContextPool::Instance().Get(place_); + + for (size_t i = 0; i < op->num_results(); ++i) { + pir::Value result = op->result(i); + auto var_name = value_exec_info.GetVarName(result); + + auto tensor = value_exec_info.GetScope() + ->Var(var_name) + ->GetMutable(); + + tensor_args_.push_back(tensor); + + out_tensor_ = tensor; + + auto alloc_tensor_type = + result.type().dyn_cast(); + tensor->set_type( + paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype())); + tensor->Resize(alloc_tensor_type.dims()); + } } void CinnJitInstruction::Run() { - VLOG(6) << "Run cinn jit_kernel_op : " << Name(); - impl_->Run(); + auto gpu_ctx = static_cast(dev_ctx_); + + auto stream = gpu_ctx->stream(); + for (size_t i = 0; i < tensor_args_.size(); ++i) { + gpu_ctx->Alloc(tensor_args_[i], tensor_args_[i]->dtype()); + } + + fn_ptr_impl_->Run(tensor_args_, static_cast(stream)); } const std::string& CinnJitInstruction::Name() const { - // TODO(Aurelius84): Consider the case for instrucitons constaning - // multipule function ptrs and function names. - return impl_->pointer()->function_name(); + static const std::string name = "cinn_jit"; + return name; } } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h index 5f5e4f74e88848..ceb4014f044a6f 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h @@ -30,7 +30,7 @@ class CinnJitInstruction : public InstructionBase { CinnJitInstruction(size_t id, const platform::Place& place, ::pir::Operation* op, - Scope* scope); + const ValueExecutionInfo& value_exec_info); // TODO(Aurelius84): Only implement core interface and need implement GC and // Event logic. @@ -41,8 +41,17 @@ class CinnJitInstruction : public InstructionBase { ::pir::Operation* Operation() const override { return op_; } private: - class Impl; - std::shared_ptr impl_{nullptr}; + class FnPtrImpl; + + std::shared_ptr fn_ptr_impl_{nullptr}; + + platform::Place place_; + + phi::DeviceContext* dev_ctx_; + + phi::DenseTensor* out_tensor_; + + std::vector tensor_args_; ::pir::Operation* op_{nullptr}; // not owned }; diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index bcdb7abb7eec1c..2422597ece0d1a 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -39,64 +39,6 @@ namespace paddle { namespace framework { -std::vector GetYiedOpInputs(pir::Block* block) { - std::vector vec_res; - for (auto op : (*block)) { - if (op->name() == "cf.yield") { - for (size_t i = 0; i < op->num_operands(); ++i) { - vec_res.push_back(op->operand_source(i)); - } - } - } - return vec_res; -} - -void GetInputIds(pir::Operation* op, - const ValueExecutionInfo& value_exec_info, - std::unordered_map>* input_ids) { - for (size_t i = 0; i < op->num_operands(); i++) { - pir::Value value = op->operand_source(i); - if (value && value.type()) { - PADDLE_ENFORCE_EQ( - value_exec_info.HasValue(value), - true, - phi::errors::PreconditionNotMet( - "input should in name map, [%d] 'th input of [%s] op", - i, - "if op")); - input_ids->emplace(value, GetValueIds(value, value_exec_info)); - } - } -} - -void GetOutsideOpInputs( - pir::Block* block, - const ValueExecutionInfo& value_exec_info, - std::unordered_map>* input_ids) { - std::unordered_set inner_outputs; - for (auto op : (*block)) { - for (size_t i = 0; i < op->num_results(); ++i) { - inner_outputs.insert(op->result(i)); - } - } - - for (auto op : (*block)) { - for (size_t i = 0; i < op->num_operands(); ++i) { - pir::Value value = op->operand_source(i); - if (value && (!inner_outputs.count(value))) { - PADDLE_ENFORCE_EQ( - value_exec_info.HasValue(value), - true, - phi::errors::PreconditionNotMet( - "input should in name map, [%d] 'th input of [%s] op", - i, - op->name())); - input_ids->emplace(value, GetValueIds(value, value_exec_info)); - } - } - } -} - CondInstruction::CondInstruction(size_t id, const platform::Place& place, pir::Operation* op, @@ -120,8 +62,36 @@ CondInstruction::CondInstruction(size_t id, } VLOG(6) << "finish process cond_var and output_vars"; + // NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is + // OpOperand of IfOp, and the other is external Values used in true_block or + // false_block. auto true_branch_block = if_op.true_block(); - auto true_branch_yied_inputs = GetYiedOpInputs(true_branch_block); + auto false_branch_block = if_op.false_block(); + std::unordered_map> inputs; + GetInputIds(op, *value_exec_info, &inputs); + auto true_outside_inputs = + GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); + auto false_outside_inputs = + GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); + SetInputs(inputs); + + std::unordered_map> outputs; + for (size_t i = 0; i < op->num_results(); i++) { + pir::Value value = op->result(i); + if (value && value.type()) { + PADDLE_ENFORCE_EQ( + value_exec_info->HasValue(value), + true, + phi::errors::PreconditionNotMet( + "input should in name map, [%d] 'th input of [%s] op", + i, + "if op")); + outputs.emplace(value, GetValueIds(value, *value_exec_info)); + } + } + SetOutputs(outputs); + VLOG(6) << "finish process inputs outputs index"; + Scope* true_scope = &(value_exec_info->GetScope()->NewScope()); true_branch_inter_ = new NewIRInterpreter(place, @@ -132,15 +102,20 @@ CondInstruction::CondInstruction(size_t id, {}); std::set true_skip_gc_names_set; - for (auto value : true_branch_yied_inputs) { + for (auto value : GetYiedOpInputs(true_branch_block)) { + true_branch_outputs_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); + true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); + } + // NOTE(zhangbo): According to the concept of control flow, child scopes + // should not control the lifecycle of parent scope variables. + for (auto value : true_outside_inputs) { true_skip_gc_names_.push_back(true_branch_inter_->GetNameByValue(value)); true_skip_gc_names_set.insert(true_branch_inter_->GetNameByValue(value)); } true_branch_inter_->SetSkipGcVars(true_skip_gc_names_set); VLOG(6) << "finish process true branch interpreter"; - auto false_branch_block = if_op.false_block(); - auto false_branch_yied_inputs = GetYiedOpInputs(false_branch_block); Scope* false_scope = &(value_exec_info->GetScope()->NewScope()); false_branch_inter_ = new NewIRInterpreter(place, @@ -151,38 +126,17 @@ CondInstruction::CondInstruction(size_t id, {}); std::set false_skip_gc_names_set; - for (auto value : false_branch_yied_inputs) { + for (auto value : GetYiedOpInputs(false_branch_block)) { + false_branch_outputs_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); + false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); + } + for (auto value : false_outside_inputs) { false_skip_gc_names_.push_back(false_branch_inter_->GetNameByValue(value)); false_skip_gc_names_set.insert(false_branch_inter_->GetNameByValue(value)); } false_branch_inter_->SetSkipGcVars(false_skip_gc_names_set); VLOG(6) << "finish process false branch interpreter"; - - // NOTE(zhangbo): IfOp sub_block's inputs include two kind of value: one is - // OpOperand of IfOp, and the other is external Values used in true_block or - // false_block. - std::unordered_map> inputs; - GetInputIds(op, *value_exec_info, &inputs); - GetOutsideOpInputs(true_branch_block, *value_exec_info, &inputs); - GetOutsideOpInputs(false_branch_block, *value_exec_info, &inputs); - SetInputs(inputs); - - std::unordered_map> outputs; - for (size_t i = 0; i < op->num_results(); i++) { - pir::Value value = op->result(i); - if (value && value.type()) { - PADDLE_ENFORCE_EQ( - value_exec_info->HasValue(value), - true, - phi::errors::PreconditionNotMet( - "input should in name map, [%d] 'th input of [%s] op", - i, - "if op")); - outputs.emplace(value, GetValueIds(value, *value_exec_info)); - } - } - SetOutputs(outputs); - VLOG(6) << "finish process inputs outputs index"; } CondInstruction::~CondInstruction() { @@ -208,10 +162,10 @@ void CondInstruction::Run() { DeviceContext().Wait(); if (cond_var_->Get().data()[0]) { true_branch_inter_->Run({}, false); - CopyBranchOutput(true_skip_gc_names_, true_branch_inter_); + CopyBranchOutput(true_branch_outputs_, true_branch_inter_); } else { false_branch_inter_->Run({}, false); - CopyBranchOutput(false_skip_gc_names_, false_branch_inter_); + CopyBranchOutput(false_branch_outputs_, false_branch_inter_); } // copy ouptut diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h index 1cdc4a388126a6..469c0ed0ae1ab8 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.h @@ -58,6 +58,10 @@ class CondInstruction : public InstructionBase { NewIRInterpreter* false_branch_inter_; + std::vector true_branch_outputs_; + + std::vector false_branch_outputs_; + // TODO(zhangbo): Currently, only the output of IfOp is included. In the // future, need to consider how to support IfGradOp using IfOp value. std::vector true_skip_gc_names_; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc index 0b494c29dea86d..62419acffc099f 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.cc @@ -217,8 +217,11 @@ void InstructionBase::SetOutputs( void InstructionBase::InitInputsOutputsIds( ::pir::Operation* op, const ValueExecutionInfo& value_exec_info) { auto op_attributes = op->attributes(); - auto op_name = - op_attributes.at("op_name").dyn_cast().AsString(); + std::string op_name; + if (op_attributes.count("op_name ")) { + op_name = + op_attributes.at("op_name").dyn_cast().AsString(); + } std::unordered_map> inputs; for (size_t i = 0; i < op->num_operands(); i++) { pir::Value value = op->operand_source(i); @@ -257,8 +260,7 @@ void InstructionBase::InitInputsOutputsIds( std::string InstructionBase::DebugStringEx( const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const { + ValueExecutionInfo* value_exe_info) const { std::stringstream ss; ss << "Op(" << Name() << "), inputs:{"; @@ -268,7 +270,7 @@ std::string InstructionBase::DebugStringEx( auto& input = *it; bool is_no_need_buffer_var = (!no_need_buffer_vars.empty() && no_need_buffer_vars.count(input.first) > 0); - auto var_name = value_2_var_name.at(input.first); + auto var_name = value_exe_info->GetVarName(input.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { @@ -296,7 +298,7 @@ std::string InstructionBase::DebugStringEx( ss << "}, outputs:{"; for (auto it = Outputs().begin(); it != Outputs().end();) { auto& output = *it; - auto var_name = value_2_var_name.at(output.first); + auto var_name = value_exe_info->GetVarName(output.first); ss << var_name; if (scope) { if (!VarInited(*scope, var_name)) { diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index 60797426119154..5dd7ff3e4d2a5d 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -144,10 +144,8 @@ class InstructionBase { const ValueExecutionInfo& value_exec_info); // if scope is not null, also show dimensions of arguments - virtual std::string DebugStringEx( - const paddle::framework::Scope* scope, - const std::unordered_map<::pir::Value, std::string>& value_2_var_name) - const; + virtual std::string DebugStringEx(const paddle::framework::Scope* scope, + ValueExecutionInfo* value_exe_info) const; protected: size_t id_; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index 6459c094a8e7b1..4066bc7afb3dc6 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" #include "paddle/phi/core/distributed/comm_context_manager.h" @@ -190,5 +191,66 @@ OpFuncType AnalyseOpFuncType(pir::Operation* op, const platform::Place& place) { return OpFuncType::kGpuAsync; } +std::vector GetYiedOpInputs(pir::Block* block) { + std::vector vec_res; + + if (block && !block->empty() && block->back()->isa()) { + auto* op = block->back(); + for (size_t i = 0; i < op->num_operands(); ++i) { + vec_res.emplace_back(op->operand_source(i)); + } + } + return vec_res; +} + +void GetInputIds(pir::Operation* op, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* input_ids) { + for (size_t i = 0; i < op->num_operands(); i++) { + pir::Value value = op->operand_source(i); + if (value && value.type()) { + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, + phi::errors::PreconditionNotMet( + "input should in name map, [%d] 'th input of [%s] op", + i, + "if op")); + input_ids->emplace(value, GetValueIds(value, value_exec_info)); + } + } +} + +std::vector GetOutsideOpInputs( + pir::Block* block, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* input_ids) { + std::unordered_set inner_outputs; + for (auto op : (*block)) { + for (size_t i = 0; i < op->num_results(); ++i) { + inner_outputs.insert(op->result(i)); + } + } + + std::vector outside_op_inputs; + for (auto op : (*block)) { + for (size_t i = 0; i < op->num_operands(); ++i) { + pir::Value value = op->operand_source(i); + if (value && (!inner_outputs.count(value))) { + PADDLE_ENFORCE_EQ( + value_exec_info.HasValue(value), + true, + phi::errors::PreconditionNotMet( + "input should in name map, [%d] 'th input of [%s] op", + i, + op->name())); + input_ids->emplace(value, GetValueIds(value, value_exec_info)); + outside_op_inputs.push_back(value); + } + } + } + return outside_op_inputs; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.h b/paddle/fluid/framework/new_executor/instruction/instruction_util.h index dd1b98fa3dc15e..8304b134e05341 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.h @@ -43,5 +43,16 @@ platform::DeviceContext* ParseDeviceContext( OpFuncType AnalyseOpFuncType(::pir::Operation* op, const platform::Place& place); +std::vector GetYiedOpInputs(pir::Block* block); + +void GetInputIds(pir::Operation* op, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* input_ids); + +std::vector GetOutsideOpInputs( + pir::Block* block, + const ValueExecutionInfo& value_exec_info, + std::unordered_map>* input_ids); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc index 97bda347770081..6a8ecd09c4cece 100644 --- a/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.cc @@ -166,12 +166,12 @@ LegacyKernelInstruction::~LegacyKernelInstruction() { } void LegacyKernelInstruction::Run() { + VLOG(6) << "Run op " << legacy_op_name_ << " infer meta."; if (infer_meta_interface_) { infer_meta_interface_->infer_meta_(&(infer_meta_context_)); } - VLOG(6) << "Run op " << legacy_op_name_ << " infer meta."; - (*(phi_kernel_))((kernel_context_)); VLOG(6) << "Run op " << legacy_op_name_ << " kernel."; + (*(phi_kernel_))((kernel_context_)); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc new file mode 100644 index 00000000000000..b511ad1f602320 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" + +#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" +#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" +#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/meta_tensor.h" +#include "paddle/phi/core/type_defs.h" + +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" + +namespace paddle { +namespace framework { + +WhileInstruction::WhileInstruction(size_t id, + const platform::Place& place, + pir::Operation* op, + Scope* scope, + Scope* local_scope, + ValueExecutionInfo* parent_exe_info) + : InstructionBase(id, place) { + op_ = op; + VLOG(6) << "finish process dist attributes"; + + SetKernelType(AnalyseOpFuncType(op, place)); + VLOG(6) << "finish process analyse kernel type"; + + Scope* inner_scope = local_scope == nullptr ? scope : local_scope; + + VLOG(6) << "finish process inputs outputs index"; + + PADDLE_ENFORCE(op->isa(), + phi::errors::PreconditionNotMet( + "While instruction only support While op")); + + auto while_op = op->dyn_cast(); + + cond_var_ = inner_scope->GetVar( + parent_exe_info->GetValue2VarName().at(while_op.operand_source(0))); + for (size_t i = 1; i < while_op.num_operands(); ++i) { + inputs_.push_back(inner_scope->GetVar( + parent_exe_info->GetValue2VarName().at(while_op.operand_source(i)))); + } + + for (size_t i = 0; i < while_op.num_results(); ++i) { + outputs_.push_back(inner_scope->GetVar( + parent_exe_info->GetValue2VarName().at(while_op.result(i)))); + } + + body_block_ = while_op.body_block(); + auto body_block_outputs = GetYiedOpInputs(body_block_); + + Scope* body_scope = &(parent_exe_info->GetScope()->NewScope()); + auto body_exe_info = parent_exe_info->NewChild(body_scope); + for (size_t i = 0; i < body_block_->args_size(); ++i) { + auto var_name = "body_block_arg_" + std::to_string(i); + body_scope->Var(var_name); + body_exe_info->Add(body_block_->argument(i), var_name); + } + body_inter_ = std::unique_ptr(new NewIRInterpreter( + place, {}, body_block_, body_scope, body_exe_info, {})); + + std::set body_skip_gc_names_set; + for (auto value : body_block_outputs) { + body_skip_gc_names_.push_back(body_inter_->GetNameByValue(value)); + body_skip_gc_names_set.insert(body_inter_->GetNameByValue(value)); + } + body_inter_->SetSkipGcVars(body_skip_gc_names_set); + + std::unordered_map> inputs; + GetInputIds(op, *parent_exe_info, &inputs); + + SetInputs(inputs); + + std::unordered_map> outputs; + for (size_t i = 0; i < op->num_results(); i++) { + pir::Value value = op->result(i); + if (value && value.type()) { + PADDLE_ENFORCE_NE( + parent_exe_info->GetValue2VarName().find(value), + parent_exe_info->GetValue2VarName().end(), + phi::errors::PreconditionNotMet( + "output should in name map, [%d] 'th output of [%s] op", + i, + "while op")); + std::vector outputs_id = GetValueIds(value, *parent_exe_info); + outputs.emplace(value, outputs_id); + } + } + SetOutputs(outputs); +} + +void WhileInstruction::CopyInputsToOutputs() { + for (size_t i = 0; i < outputs_.size(); ++i) { + outputs_[i]->GetMutable()->ShareDataWith( + inputs_[i]->Get()); + } +} + +void WhileInstruction::PassArgsToBodyBlock() { + for (size_t i = 0; i < body_block_->args_size(); ++i) { + auto block_arg = body_block_->argument(i); + auto var_name = body_inter_->GetNameByValue(block_arg); + auto* inner_var = body_inter_->local_scope()->GetVar(var_name); + inner_var->GetMutable()->ShareDataWith( + outputs_[i]->Get()); + } +} + +void WhileInstruction::GetValueFromBodyBlock() { + cond_var_->GetMutable()->ShareDataWith( + body_inter_->local_scope() + ->GetVar(body_skip_gc_names_[0]) + ->Get()); + for (size_t i = 0; i < outputs_.size(); ++i) { + auto& out_var_name = body_skip_gc_names_[i + 1]; + auto* out_var = body_inter_->local_scope()->GetVar(out_var_name); + outputs_[i]->GetMutable()->ShareDataWith( + out_var->Get()); + } +} +void WhileInstruction::Run() { + CopyInputsToOutputs(); + while (cond_var_->Get().data()[0]) { + PassArgsToBodyBlock(); + body_inter_->Run({}, false); + GetValueFromBodyBlock(); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/instruction/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/while_instruction.h new file mode 100644 index 00000000000000..d486c8206c5026 --- /dev/null +++ b/paddle/fluid/framework/new_executor/instruction/while_instruction.h @@ -0,0 +1,77 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" + +namespace ir { +class Operation; +} // namespace ir + +namespace paddle { +namespace framework { +class Scope; +class Value; +class NewIRInterpreter; +class ValueExecutionInfo; + +/// The execute semantics of while op ['output' = while_op('cond', 'intput')] +/// is: +/// 'output' = 'input'; +/// while('cond') { +/// 'cond', 'output' = body_block('output'); +/// } +class WhileInstruction : public InstructionBase { + public: + WhileInstruction(size_t id, + const platform::Place& place, + ::pir::Operation* op, + Scope* scope, + Scope* local_scope, + ValueExecutionInfo* parent_exe_info); + + void Run() override; + + const std::string& Name() const override { return name_; } + + ::pir::Operation* Operation() const override { return op_; } + + private: + // 'output' = 'input' + void CopyInputsToOutputs(); + + // Pass argument to body_block for execution. + void PassArgsToBodyBlock(); + + // Get return value from body_block after each execution. + void GetValueFromBodyBlock(); + + std::string name_{"while_instruction"}; + + Variable* cond_var_; + + std::vector inputs_; + std::vector outputs_; + + std::unique_ptr body_inter_; + std::vector body_skip_gc_names_; + + ::pir::Block* body_block_; + + ::pir::Operation* op_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc index 0baa62f8a4dcdb..0d3af1e55c2a01 100644 --- a/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/data_transfer.cc @@ -508,7 +508,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, const std::string var_name = argument_names[i]; Variable* var = arguments->at(i); - const phi::DenseTensor* tensor_in; + const phi::DenseTensor* tensor_in = nullptr; if (var->IsType() || var->IsType()) { tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc index de77780abc3e53..4ce8c411a10b25 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc @@ -42,7 +42,7 @@ namespace interpreter { size_t CountDownstreamMap( const std::map>& downstream_map) { size_t count = 0; - for (auto pair : downstream_map) { + for (auto const& pair : downstream_map) { count += pair.second.size(); } return count; @@ -50,7 +50,7 @@ size_t CountDownstreamMap( const std::string StringizeDownstreamMap( const std::map>& downstream_map) { std::ostringstream oss; - for (auto pair : downstream_map) { + for (auto const& pair : downstream_map) { oss << pair.first << " -> "; std::copy(pair.second.begin(), pair.second.end(), @@ -144,7 +144,7 @@ void DependencyBuilder::AddDependencyForCoalesceTensorOp() { auto outputs = instructions_->at(op_idx).Outputs().at("Output"); auto is_read = [](const Instruction& inst, size_t var_id) -> bool { - for (auto pair : inst.Inputs()) { + for (auto const& pair : inst.Inputs()) { for (size_t item : pair.second) { if (item == var_id) { return true; @@ -155,7 +155,7 @@ void DependencyBuilder::AddDependencyForCoalesceTensorOp() { }; auto is_write = [](const Instruction& inst, size_t var_id) -> bool { - for (auto pair : inst.Outputs()) { + for (auto const& pair : inst.Outputs()) { for (size_t item : pair.second) { if (item == var_id) { return true; diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 3d142acdc1c7a4..fef6b91f95026c 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/distributed/auto_parallel/dist_attr.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" @@ -228,7 +229,9 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) { return type == proto::VarType::LOD_TENSOR || type == proto::VarType::SELECTED_ROWS || - type == proto::VarType::LOD_TENSOR_ARRAY; + type == proto::VarType::LOD_TENSOR_ARRAY || + type == proto::VarType::SPARSE_COO || + type == proto::VarType::SPARSE_CSR; } std::unordered_map> following_ops( - ops.begin() + i + 1, ops.end()); + ops.begin() + static_cast(i) + 1, ops.end()); HandleOperatorBase(place, ops[i], &op_func_node, @@ -894,7 +897,7 @@ void BuildOpFuncList(const platform::Place& place, // avoid overwriting valid data if (static_build && original_tensor->initialized()) { const phi::Place& target_place = transformed_tensor->place(); - platform::DeviceContext* dev_ctx_for_copy; + platform::DeviceContext* dev_ctx_for_copy = nullptr; if (target_place.GetType() != AllocationType::CPU) { dev_ctx_for_copy = pool.Get(target_place); } else { @@ -934,7 +937,7 @@ void BuildOpFuncList(const platform::Place& place, } } catch (platform::EnforceNotMet& ex) { framework::InsertCallStackInfo(op_type, op->Attrs(), &ex); - throw std::move(ex); + throw ex; } catch (platform::EOFException&) { std::rethrow_exception(std::current_exception()); } catch (std::exception& ex) { @@ -1002,6 +1005,33 @@ void BuildOpFuncList(const platform::Place& place, if (var->IsType()) { garbages->emplace_back( var->GetMutable()->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_value() + ->MoveMemoryHolder()); + var->GetMutable()->mutable_rows()->clear(); + } else if (var->IsType()) { + auto* tensor_arr = var->GetMutable(); + for (auto& t : *tensor_arr) { + garbages->emplace_back(t.MoveMemoryHolder()); + } + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_indices() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_cols() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_crows() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); } } delete garbages; // free mem @@ -1022,6 +1052,33 @@ void BuildOpFuncList(const platform::Place& place, if (var->IsType()) { garbages->emplace_back( var->GetMutable()->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_value() + ->MoveMemoryHolder()); + var->GetMutable()->mutable_rows()->clear(); + } else if (var->IsType()) { + auto* tensor_arr = var->GetMutable(); + for (auto& t : *tensor_arr) { + garbages->emplace_back(t.MoveMemoryHolder()); + } + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_indices() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages->emplace_back(var->GetMutable() + ->mutable_cols() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_crows() + ->MoveMemoryHolder()); + garbages->emplace_back(var->GetMutable() + ->mutable_values() + ->MoveMemoryHolder()); } } delete garbages; @@ -1138,7 +1195,7 @@ std::unordered_set GetSpecialOpNames() { "builtin.set_parameter", "builtin.get_parameter", "pd_op.data", - "pd_op.shadow_output", + "builtin.shadow_output", }; } diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index e8e5a1ef29aedf..bebeb142d473f1 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -330,7 +330,7 @@ void FakeInitializeTensor(const platform::DeviceContext& dev_ctx, // set place if (tensor->initialized()) { // avoid overwriting valid data - platform::DeviceContext* dev_ctx_for_copy; + platform::DeviceContext* dev_ctx_for_copy = nullptr; if (place.GetType() != AllocationType::CPU) { dev_ctx_for_copy = platform::DeviceContextPool::Instance().Get(place); } else { diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index bbbaf4c0dd75f2..3f356270e05702 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -431,6 +431,7 @@ void analyse_event_info_for_two_instructions( if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || + !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || instructions[next_instr_id]->OpBase()->Type() == "depend") { waiter_instr_ids->insert(next_instr_id); return; @@ -490,6 +491,7 @@ void analyse_event_info_for_two_instructions< if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || + !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || instructions[next_instr_id]->Name() == "pd_op.depend") { waiter_instr_ids->insert(next_instr_id); return; diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index b6b2be142a0dd9..fb3a7f524636a1 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -29,6 +29,8 @@ #include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_context.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -42,6 +44,7 @@ #include "paddle/fluid/framework/new_executor/instruction/cond_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" +#include "paddle/fluid/framework/new_executor/instruction/while_instruction.h" #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" @@ -236,7 +239,8 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { scope_ = new_scope; for (size_t i = 0; i < value_exe_info_->GetVarList().size(); i++) { const auto& var_name = value_exe_info_->GetNameById(static_cast(i)); - value_exe_info_->ResetVarList(i, new_scope->FindVar(var_name)); + value_exe_info_->ResetVarList(static_cast(i), + new_scope->FindVar(var_name)); } // The index should be assured valid, cause the InterpreterCore may not be // fully built, but was still cached and used. For example, see unit test @@ -386,7 +390,7 @@ Scope* NewIRInterpreter::InnerScope() const { } std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const { - return value_exe_info_->GetValue2VarName().at(value); + return value_exe_info_->GetVarName(value); } void NewIRInterpreter::UpdateSyncOpNum() { @@ -463,7 +467,7 @@ void NewIRInterpreter::UpdateNcclOpNum() { "pd_op.global_gather_grad", "pd_op.distributed_fused_lamb_grad", "pd_op.margin_cross_entropy_grad", - "pd_op.margin_cross_entropy_grad_" + "pd_op.margin_cross_entropy_grad_", "pd_op.sync_batch_norm_grad", "pd_op.data_norm_grad", "pd_op.class_center_sample_grad", @@ -566,9 +570,17 @@ void NewIRInterpreter::BuildInstruction() { } else if (op->dialect()->name() == "cf") { VLOG(6) << "skip process cf dialect op: " << op->name(); continue; - } else if (op->isa()) { - vec_instruction_base_.emplace_back(std::make_unique( - op_idx++, place_, op, value_exe_info_.get())); + } else if (op->dialect()->name() == "pd_op") { + if (op->isa()) { + vec_instruction_base_.emplace_back(std::make_unique( + op_idx++, place_, op, value_exe_info_.get())); + } else if (op->isa()) { + vec_instruction_base_.emplace_back(std::make_unique( + op_idx++, place_, op, scope_, local_scope_, value_exe_info_.get())); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Now only support pd_kernel and cinn dialect.")); + } } else if (op->dialect()->name() == "pd_kernel") { auto op_name = op->attributes() .at("op_name") @@ -591,8 +603,8 @@ void NewIRInterpreter::BuildInstruction() { } #ifdef PADDLE_WITH_CINN } else if (op->dialect()->name() == "cinn_runtime") { - vec_instruction_base_.emplace_back( - std::make_unique(op_idx++, place_, op, scope_)); + vec_instruction_base_.emplace_back(std::make_unique( + op_idx++, place_, op, *(value_exe_info_.get()))); #endif } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -615,7 +627,7 @@ std::string NewIRInterpreter::DebugValueInfo() { PADDLE_ENFORCE((bool)kv.first, platform::errors::PreconditionNotMet( "vlaue(%s) should not be nullptr", kv.second)); - PADDLE_ENFORCE(value_exe_info_->GetVarName2Id().count(kv.second) > 0, + PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second), platform::errors::PreconditionNotMet( "var(%s) should exist in var_name_2_id_", kv.second)); auto* var = InnerScope()->FindVar(kv.second); @@ -624,8 +636,7 @@ std::string NewIRInterpreter::DebugValueInfo() { platform::errors::PreconditionNotMet( "var(%s) should exist in scope (%p)", kv.second, InnerScope())); os << kv.first.impl() << " -> " << kv.second << " -> " - << value_exe_info_->GetVarName2Id().at(kv.second) << " -> " << var - << "\n"; + << value_exe_info_->GetVarId(kv.first) << " -> " << var << "\n"; } return os.str(); } @@ -793,6 +804,18 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) { for (auto& tensor : *tensor_arr) { TensorRecordStream(tensor); } + } else if (var->IsType()) { + TensorRecordStream( + *(var->GetMutable()->mutable_indices())); + TensorRecordStream( + *(var->GetMutable()->mutable_values())); + } else if (var->IsType()) { + TensorRecordStream( + *(var->GetMutable()->mutable_cols())); + TensorRecordStream( + *(var->GetMutable()->mutable_crows())); + TensorRecordStream( + *(var->GetMutable()->mutable_values())); } else if (var->IsType>()) { // do nothing } else { @@ -833,6 +856,7 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) { } void NewIRInterpreter::CalculateLastLiveOps() { + VLOG(4) << "NewIRInterpreter(): " << this << " start CalculateLastLiveOps"; // calculate last_live_ops_ for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) { InstructionBase* instr = vec_instruction_base_[op_idx].get(); @@ -858,13 +882,20 @@ void NewIRInterpreter::CalculateLastLiveOps() { gc_check_vars.insert(var_id); } } + VLOG(4) << "get gc check vars for: " << instr->Name(); for (auto var_id : gc_check_vars) { Scope* inner_scope = InnerScope(); paddle::framework::Variable* var = inner_scope->FindVar( value_exe_info_->GetNameById(static_cast(var_id))); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::NotFound("Var(id=%d) should not be nullptr.", + static_cast(var_id))); if (var->IsType() || var->IsType() || - var->IsType()) { + var->IsType() || + var->IsType() || + var->IsType()) { last_live_ops_[var_id].insert(op_idx); } else { VLOG(4) << "not clear " @@ -873,6 +904,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { << framework::ToTypeName(var->Type()); } } + VLOG(4) << "update last_live_ops for: " << instr->Name(); } // clear the last_live_ops list for all vars in skip_gc_vars for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) { @@ -882,7 +914,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { VLOG(8) << "Skip gc for var: " << skip_gc_var; } } - VLOG(4) << "calculate last_live_ops_"; + VLOG(4) << "clear the last_live_ops list for all vars in skip_gc_vars"; // shrink, find the downstream op that has no other op in the // downstream list happens before it @@ -923,6 +955,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { last_live_ops_[i] = minumum_last_live_ops; var_ref_count_[i] = static_cast(last_live_ops_[i].size()); } + VLOG(4) << "shrink the last_live_ops list for all vars in skip_gc_vars"; for (auto& dep : *dependecy_count_) { deps_.emplace_back(std::make_shared(dep)); @@ -931,6 +964,7 @@ void NewIRInterpreter::CalculateLastLiveOps() { refs_.emplace_back(std::make_shared( var_ref_count_[i], value_exe_info_->GetVarList()[i])); } + VLOG(4) << "done CalculateLastLiveOps"; } void NewIRInterpreter::ConstructEventForJitInput() { @@ -1384,8 +1418,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); if (!instr_node->IsArtificial()) { instr_node->Run(); @@ -1407,8 +1440,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) { : "kGpuAsync")) << " runs on " << platform::GetCurrentThreadName(); VLOG(4) << place_ << " " - << instr_node->DebugStringEx(scope_, - value_exe_info_->GetValue2VarName()); + << instr_node->DebugStringEx(scope_, value_exe_info_.get()); CheckGC(instr_node); VLOG(4) << "done CheckGC"; interpreter::LogDeviceMemoryStats(place_); @@ -1483,6 +1515,9 @@ void NewIRInterpreter::SolvePersisableVarNames() { ::pir::Value value = kv.first; const std::string& var_name = kv.second; ::pir::OpResult result = value.dyn_cast<::pir::OpResult>(); + if (!result) { + continue; + } auto* defining_op = result.owner(); if (defining_op->HasAttribute(kAttrIsPersisable)) { auto is_persisables = diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 4142b3fe872f1a..3ae75ffd870088 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -50,6 +50,11 @@ std::shared_ptr ValueExecutionInfo::NewChild(Scope* scope) { std::shared_ptr info = std::make_shared(scope); info->parent_ = this; + info->value_2_var_name_ = this->value_2_var_name_; + info->var_2_var_name_ = this->var_2_var_name_; + info->var_name_2_id_ = this->var_name_2_id_; + info->id_2_var_name_ = this->id_2_var_name_; + info->var_list_ = this->var_list_; return info; } @@ -157,54 +162,15 @@ void ValueExecutionInfo::ResetVarList(int id, Variable* var) { var_list_[id] = var; } -bool ValueExecutionInfo::HasValue(::pir::Value value) const { - return HasValueInternal(value); -} - -bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const { - return HasValueLocally(value); -} - -std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { - return GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarName(const Variable* var) const { - return GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const { - return GetVarNameLocally(value); -} - -std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const { - return GetVarNameLocally(var); -} - -int ValueExecutionInfo::GetVarId(::pir::Value value) const { - return GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarId(const Variable* var) const { - return GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const { - return GetVarIdLocally(value); -} - -int ValueExecutionInfo::GetLocalVarId(const Variable* var) const { - return GetVarIdLocally(var); -} - -bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const { - if (HasValueLocally(value)) { +bool ValueExecutionInfo::HasVar(const std::string& var_name) const { + auto it = var_name_2_id_.find(var_name); + if (it != var_name_2_id_.end()) { return true; } - return (parent_ == nullptr) ? false : parent_->HasValueInternal(value); + return false; } -bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { +bool ValueExecutionInfo::HasValue(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return true; @@ -212,15 +178,7 @@ bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const { return false; } -std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const { - auto name = GetVarNameLocally(value); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(value); -} - -std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { +std::string ValueExecutionInfo::GetVarName(::pir::Value value) const { auto it = value_2_var_name_.find(value); if (it != value_2_var_name_.end()) { return it->second; @@ -228,15 +186,7 @@ std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const { return ""; } -std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const { - auto name = GetVarNameLocally(var); - if (name != "") { - return name; - } - return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var); -} - -std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { +std::string ValueExecutionInfo::GetVarName(const Variable* var) const { auto it = var_2_var_name_.find(var); if (it != var_2_var_name_.end()) { return it->second; @@ -244,16 +194,8 @@ std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const { return ""; } -int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const { - auto id = GetVarIdLocally(value); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value); -} - -int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { - auto var_name = GetVarNameLocally(value); +int ValueExecutionInfo::GetVarId(::pir::Value value) const { + auto var_name = GetVarName(value); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -261,16 +203,8 @@ int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const { return -1; } -int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const { - auto id = GetVarIdLocally(var); - if (id != -1) { - return id; - } - return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var); -} - -int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const { - auto var_name = GetVarNameLocally(var); +int ValueExecutionInfo::GetVarId(const Variable* var) const { + auto var_name = GetVarName(var); auto it = var_name_2_id_.find(var_name); if (it != var_name_2_id_.end()) { return it->second; @@ -286,8 +220,9 @@ const std::unordered_set SpecialOps = {"pd_op.feed", "builtin.slice", "builtin.split", "pd_op.data", - "pd_op.shadow_output", - "pd_op.if"}; + "builtin.shadow_output", + "pd_op.if", + "pd_op.while"}; Variable* CreateVar(pir::Value value, const std::string& var_name_prefix, @@ -295,8 +230,14 @@ Variable* CreateVar(pir::Value value, ValueExecutionInfo* value_exe_info) { pir::Operation* def_op = value.dyn_cast().owner(); bool is_persisable = false; - if (def_op->isa<::pir::SetParameterOp>()) { + if (def_op->isa<::pir::GetParameterOp>()) { is_persisable = true; + } else if (def_op->HasAttribute(kAttrIsPersisable)) { + is_persisable = def_op->attribute(kAttrIsPersisable) + .dyn_cast() + .AsVector()[value.dyn_cast().index()] + .dyn_cast() + .data(); } Variable* var = nullptr; @@ -472,18 +413,19 @@ void HandleForSpecialOp(pir::Operation* op, value_exe_info->Rename(value, param_name, orig_name); } - - if (op_name == "pd_op.shadow_output") { - VLOG(6) << "Handle for pd_op.shadow_ouptut"; - auto var_name = - op->attributes().at("name").dyn_cast().AsString(); + if (op_name.compare(pir::ShadowOutputOp::name()) == 0) { + VLOG(6) << "Handle for builtin.shadow_ouptut"; + auto var_name = op->attributes() + .at("output_name") + .dyn_cast() + .AsString(); auto value = op->operand_source(0); // change opreand name to param_name auto orig_name = value_exe_info->GetValue2VarName().at(value); - if (value_exe_info->GetScope()->root()->FindVar(var_name) == nullptr) { - const_cast(value_exe_info->GetScope()->root()) + if (value_exe_info->GetScope()->FindVar(var_name) == nullptr) { + const_cast(value_exe_info->GetScope()) ->Rename(orig_name, var_name); } @@ -557,11 +499,19 @@ void HandleForSpecialOp(pir::Operation* op, if (op_name == "pd_op.if") { auto if_op = op->dyn_cast(); for (size_t i = 0; i < if_op->num_results(); ++i) { - // auto true_value = true_yeid_op->operand_source(i); auto if_op_out_value = if_op->result(i); BuildValue(if_op_out_value, var_name_prefix, value_exe_info); } } + + if (op_name == "pd_op.while") { + auto while_op = op->dyn_cast(); + + for (size_t i = 0; i < while_op->num_results(); ++i) { + auto while_op_out_value = while_op->result(i); + BuildValue(while_op_out_value, var_name_prefix, value_exe_info); + } + } } void HandleForInplaceOp(pir::Operation* op, @@ -592,8 +542,7 @@ void HandleForInplaceOp(pir::Operation* op, const std::string& inplace_name = yaml_parser.InplaceName(value_name); pir::Value inplace_value = op->operand_source(yaml_parser.InputName2Id().at(inplace_name)); - std::string var_name = - value_exe_info->GetValue2VarName().at(inplace_value); + std::string var_name = value_exe_info->GetVarName(inplace_value); VLOG(4) << "inplace: " << value_name << " -> " << inplace_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); @@ -602,8 +551,7 @@ void HandleForInplaceOp(pir::Operation* op, pir::Value view_value = op->operand_source(yaml_parser.InputName2Id().at(view_name)); // const std::string& var_name = value_2_var_name->at(view_value); - const std::string& var_name = - value_exe_info->GetValue2VarName().at(view_value); + std::string var_name = value_exe_info->GetVarName(view_value); VLOG(4) << "view: " << value_name << " -> " << view_name << " (var: " << var_name << ")"; value_exe_info->AddValue2VarName(value, var_name); diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h index 87603f2f14b151..ce0484567b64f0 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h @@ -44,9 +44,11 @@ namespace paddle { namespace framework { class CondInstruction; +class WhileInstruction; class ValueExecutionInfo { public: friend class CondInstruction; + friend class WhileInstruction; explicit ValueExecutionInfo(Scope* scope) : scope_(scope) {} @@ -77,49 +79,19 @@ class ValueExecutionInfo { void ResetVarList(int id, Variable* var); - /// Check a value exist in the ValueExecutionInfo or any of its ancestors. - bool HasValue(::pir::Value value) const; + bool HasVar(const std::string& var_name) const; - /// Check a value exist in the ValueExecutionInfo. - bool HasLocalValue(::pir::Value value) const; + bool HasValue(::pir::Value value) const; std::string GetVarName(::pir::Value value) const; std::string GetVarName(const Variable* var) const; - std::string GetLocalVarName(::pir::Value value) const; - - std::string GetLocalVarName(const Variable* var) const; - int GetVarId(::pir::Value value) const; int GetVarId(const Variable* var) const; - int GetLocalVarId(::pir::Value value) const; - - int GetLocalVarId(const Variable* var) const; - private: - bool HasValueInternal(::pir::Value value) const; - - bool HasValueLocally(::pir::Value value) const; - - std::string GetVarNameInternal(::pir::Value value) const; - - std::string GetVarNameLocally(::pir::Value value) const; - - std::string GetVarNameInternal(const Variable* var) const; - - std::string GetVarNameLocally(const Variable* var) const; - - int GetVarIdInternal(::pir::Value value) const; - - int GetVarIdLocally(::pir::Value value) const; - - int GetVarIdInternal(const Variable* var) const; - - int GetVarIdLocally(const Variable* var) const; - std::shared_ptr NewChild(Scope* scope); ValueExecutionInfo* parent_{nullptr}; // not owned @@ -285,7 +257,12 @@ void BuildPhiContext(pir::Operation* op, continue; } - + PADDLE_ENFORCE_NE( + attr_map.find(t), + attr_map.end(), + phi::errors::NotFound("Not found %s in attr_map, it maybe need mapping " + "it in OpTranslator.", + t)); auto& attr_type_name = op_yaml_info.AttrTypeName(t); if (attr_type_name == "paddle::dialect::IntArrayAttribute") { ctx->EmplaceBackAttr( @@ -299,6 +276,8 @@ void BuildPhiContext(pir::Operation* op, ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); } else if (attr_type_name == "pir::FloatAttribute") { ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == "pir::DoubleAttribute") { + ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); } else if (attr_type_name == "pir::BoolAttribute") { ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); } else if (attr_type_name == "pir::StrAttribute") { diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 2e466962c4d318..f1646f50471a49 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -25,6 +25,8 @@ #include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_context.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -740,7 +742,9 @@ void ProgramInterpreter::Convert( paddle::framework::Variable* var = inner_scope->FindVar( var_scope_.GetNameById(static_cast(var_id))); if (var->IsType() || var->IsType() || - var->IsType()) { + var->IsType() || + var->IsType() || + var->IsType()) { last_live_ops_[var_id].insert(op_idx); } else { VLOG(4) << "not clear " @@ -1018,7 +1022,7 @@ void ProgramInterpreter::RunInstruction(const Instruction& instr_node) { instr_node.RecordEvent(place_); } catch (platform::EnforceNotMet& ex) { framework::InsertCallStackInfo(op->Type(), op->Attrs(), &ex); - exception_holder_.Catch(std::make_exception_ptr(std::move(ex))); + exception_holder_.Catch(std::make_exception_ptr(ex)); } catch (platform::EOFException&) { exception_holder_.Catch(std::current_exception()); } catch (std::exception& ex) { @@ -1305,6 +1309,18 @@ void ProgramInterpreter::RecordStreamForGC(const Instruction& instr) { for (auto& tensor : *tensor_arr) { TensorRecordStream(tensor); } + } else if (var->IsType()) { + TensorRecordStream( + *(var->GetMutable()->mutable_indices())); + TensorRecordStream( + *(var->GetMutable()->mutable_values())); + } else if (var->IsType()) { + TensorRecordStream( + *(var->GetMutable()->mutable_cols())); + TensorRecordStream( + *(var->GetMutable()->mutable_crows())); + TensorRecordStream( + *(var->GetMutable()->mutable_values())); } else if (var->IsType>()) { // do nothing } else { @@ -1331,6 +1347,8 @@ void ProgramInterpreter::CheckGC(const Instruction& instr) { // ignore all persistable var while GC if (var_scope.VarDesc(static_cast(var_id)) && var_scope.VarDesc(static_cast(var_id))->Persistable()) { + VLOG(4) << "Skip persistable var: " + << var_scope_.GetNameById(static_cast(var_id)); continue; } if (is_ready) { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 8810d7f15ac4b0..1846b7c9f0f71b 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -1133,7 +1133,7 @@ void OpDesc::InferShape(const BlockDesc &block) { infer_shape(&ctx); } catch (platform::EnforceNotMet &exception) { framework::AppendErrorOpHint(Type(), &exception); - throw std::move(exception); + throw exception; } catch (...) { std::rethrow_exception(std::current_exception()); } diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index e1bc5be8c64f9e..632a8cbefc63c3 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" #include "paddle/utils/flat_hash_map.h" - +#include "paddle/utils/test_macros.h" namespace paddle { namespace framework { @@ -128,7 +128,7 @@ class OpInfo { } }; -class OpInfoMap { +class TEST_API OpInfoMap { public: static OpInfoMap& Instance(); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7a3271a48debc8..17d5f6c4f356a1 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -797,7 +797,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { VLOG(3) << GetExecutionPlace(place) << " " << DebugStringEx(&scope); } catch (platform::EnforceNotMet& exception) { framework::InsertCallStackInfo(Type(), Attrs(), &exception); - throw std::move(exception); + throw exception; } catch (platform::EOFException&) { std::rethrow_exception(std::current_exception()); } catch (std::exception& ex) { @@ -1712,8 +1712,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, VLOG(6) << *kernel_signature_.get(); phi_kernel_name = kernel_signature_->name; - kernel_type_ = std::make_unique( - std::move(InnerGetExpectedKernelType(exe_ctx))); + kernel_type_ = + std::make_unique(InnerGetExpectedKernelType(exe_ctx)); dev_ctx = pool.Get(kernel_type_->place_); // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the @@ -2220,8 +2220,8 @@ phi::KernelKey OperatorWithKernel::ChoosePhiKernel( } VLOG(6) << *kernel_signature_.get(); phi_kernel_name = kernel_signature_->name; - kernel_type_ = std::make_unique( - std::move(InnerGetExpectedKernelType(ctx))); + kernel_type_ = + std::make_unique(InnerGetExpectedKernelType(ctx)); auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); phi_kernel_ = @@ -3249,33 +3249,32 @@ void OperatorWithKernel::BuildPhiKernelContext( // scalar is in the attribute switch (AttrTypeID(attr_iter->second)) { case proto::AttrType::FLOAT: - phi_kernel_context->EmplaceBackAttr(std::move( - phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(float, attr_iter->second))); break; case proto::AttrType::FLOAT64: - phi_kernel_context->EmplaceBackAttr(std::move( - phi::Scalar(PADDLE_GET_CONST(double, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(double, attr_iter->second))); break; case proto::AttrType::INT: - phi_kernel_context->EmplaceBackAttr(std::move( - phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second))); break; case proto::AttrType::LONG: - phi_kernel_context->EmplaceBackAttr(std::move( - phi::Scalar(PADDLE_GET_CONST(int64_t, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(int64_t, attr_iter->second))); break; case proto::AttrType::STRING: - phi_kernel_context->EmplaceBackAttr(std::move(phi::Scalar( - PADDLE_GET_CONST(std::string, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr(phi::Scalar( + PADDLE_GET_CONST(std::string, attr_iter->second))); break; case proto::AttrType::BOOLEAN: - phi_kernel_context->EmplaceBackAttr(std::move( - phi::Scalar(PADDLE_GET_CONST(bool, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(bool, attr_iter->second))); break; case proto::AttrType::SCALAR: - phi_kernel_context->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST( - paddle::experimental::Scalar, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr(phi::Scalar(PADDLE_GET_CONST( + paddle::experimental::Scalar, attr_iter->second))); break; default: PADDLE_THROW(platform::errors::Unimplemented( @@ -3448,7 +3447,7 @@ void OperatorWithKernel::BuildPhiKernelContext( } break; case phi::AttributeType::STRING: phi_kernel_context->EmplaceBackAttr( - std::move(PADDLE_GET_CONST(std::string, attr_iter->second))); + PADDLE_GET_CONST(std::string, attr_iter->second)); break; case phi::AttributeType::INT64S: switch (AttrTypeID(attr_iter->second)) { diff --git a/paddle/fluid/framework/phi_utils.h b/paddle/fluid/framework/phi_utils.h index 67153a7001ece8..d1eb5558c54541 100644 --- a/paddle/fluid/framework/phi_utils.h +++ b/paddle/fluid/framework/phi_utils.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/init_default_kernel_signature_map.h" #include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" @@ -60,8 +61,6 @@ class KernelArgsNameMaker { virtual const paddle::small_vector& GetAttrsArgsNames() = 0; }; -TEST_API void InitDefaultKernelSignatureMap(); - // TODO(Wilber): support others device context. template struct ConvertToPhiContext { diff --git a/paddle/fluid/framework/program_utils.cc b/paddle/fluid/framework/program_utils.cc index 2d8a35ca00a76f..e8e64b68d767e4 100644 --- a/paddle/fluid/framework/program_utils.cc +++ b/paddle/fluid/framework/program_utils.cc @@ -90,15 +90,15 @@ void ProgramProcessor::GetInputsOutputsInBlock( // NOTE: Here assumes that all variables are input or output of Ops, for (OpDesc *op : current_block.AllOps()) { - for (auto iname : op->InputNames()) { - for (auto in_var_name : op->Input(iname)) { + for (auto const &iname : op->InputNames()) { + for (auto const &in_var_name : op->Input(iname)) { VLOG(3) << "insert inner_inputs_name:" << in_var_name; inner_inputs->insert(in_var_name); } } - for (auto oname : op->OutputNames()) { - for (auto out_var_name : op->Output(oname)) { + for (auto const &oname : op->OutputNames()) { + for (auto const &out_var_name : op->Output(oname)) { VLOG(3) << "insert out_var_name:" << out_var_name; inner_outputs->insert(out_var_name); } @@ -150,7 +150,7 @@ void ProgramProcessor::AddDepToBlockOp(const BlockDesc &block) { VLOG(3) << "sub_outputs.size:" << sub_outputs.size(); auto *op_inputs = op->MutableInputs(); - std::vector *op_input_var_vec; + std::vector *op_input_var_vec = nullptr; VLOG(3) << "op_type:>>>>>>" << op_type; if (op_type.compare("while") == 0) { op_input_var_vec = &((*op_inputs)["kX"]); @@ -163,7 +163,7 @@ void ProgramProcessor::AddDepToBlockOp(const BlockDesc &block) { continue; } - for (auto sub_input : sub_inputs) { + for (auto const &sub_input : sub_inputs) { if (std::find(op_input_var_vec->begin(), op_input_var_vec->end(), sub_input) == op_input_var_vec->end()) @@ -175,7 +175,7 @@ void ProgramProcessor::AddDepToBlockOp(const BlockDesc &block) { auto *op_outputs = op->MutableOutputs(); auto *op_output_var_vec = &((*op_outputs)["kOutputs"]); - for (auto sub_output : sub_outputs) { + for (auto const &sub_output : sub_outputs) { if (std::find(op_output_var_vec->begin(), op_output_var_vec->end(), sub_output) == op_output_var_vec->end()) diff --git a/paddle/fluid/framework/pull_dense_worker.cc b/paddle/fluid/framework/pull_dense_worker.cc index db8506e9ec5c92..f295fa7106dd43 100644 --- a/paddle/fluid/framework/pull_dense_worker.cc +++ b/paddle/fluid/framework/pull_dense_worker.cc @@ -45,7 +45,7 @@ void PullDenseWorker::Initialize(const TrainerDesc& param) { uint64_t tid = static_cast( dwp_param_.program_config(0).pull_dense_table_id(i)); TableParameter table; - for (auto i : param_.dense_table()) { + for (auto const& i : param_.dense_table()) { if (i.table_id() == tid) { table = i; break; diff --git a/paddle/fluid/framework/scope_pool.cc b/paddle/fluid/framework/scope_pool.cc index 1f7aba8e225bde..833848864a785e 100644 --- a/paddle/fluid/framework/scope_pool.cc +++ b/paddle/fluid/framework/scope_pool.cc @@ -29,7 +29,7 @@ void ScopePool::Insert(std::unique_ptr &&s) { } void ScopePool::Remove(Scope *s) { - size_t has_scope; + size_t has_scope = 0; { std::lock_guard guard(mtx_); has_scope = scopes_.erase(s); diff --git a/paddle/fluid/framework/selected_rows_utils.cc b/paddle/fluid/framework/selected_rows_utils.cc index d74e45449226f5..3f72ced811390c 100644 --- a/paddle/fluid/framework/selected_rows_utils.cc +++ b/paddle/fluid/framework/selected_rows_utils.cc @@ -45,7 +45,7 @@ void SerializeToStream(std::ostream& os, void SerializeToStream(std::ostream& os, const phi::SelectedRows& selected_rows) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext* dev_ctx; + const platform::DeviceContext* dev_ctx = nullptr; auto place = selected_rows.place(); dev_ctx = pool.Get(place); SerializeToStream(os, selected_rows, *dev_ctx); @@ -53,7 +53,7 @@ void SerializeToStream(std::ostream& os, void DeserializeFromStream(std::istream& is, phi::SelectedRows* selected_rows) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext* dev_ctx; + const platform::DeviceContext* dev_ctx = nullptr; dev_ctx = pool.Get(platform::CPUPlace()); DeserializeFromStream(is, selected_rows, *dev_ctx); } @@ -63,7 +63,7 @@ void DeserializeFromStream(std::istream& is, const platform::DeviceContext& dev_ctx) { { // the 1st field, unit32_t version for SelectedRows - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ(version, 0U, @@ -86,7 +86,7 @@ void DeserializeFromStream(std::istream& is, } { // the 3st field, the height of the SelectedRows - int64_t height; + int64_t height = 0; is.read(reinterpret_cast(&height), sizeof(int64_t)); selected_rows->set_height(height); } diff --git a/paddle/fluid/framework/string_array.cc b/paddle/fluid/framework/string_array.cc index 58c658a67c69eb..07e3f07294fae6 100644 --- a/paddle/fluid/framework/string_array.cc +++ b/paddle/fluid/framework/string_array.cc @@ -81,20 +81,20 @@ void StringMapToStream(std::ostream& os, void StringMapFromStream(std::istream& is, std::unordered_map* data) { // first read the map size - size_t map_size; + size_t map_size = 0; is.read(reinterpret_cast(&map_size), sizeof(map_size)); data->reserve(map_size); // then read the data for (size_t i = 0; i < map_size; ++i) { // read the token - size_t token_length; + size_t token_length = 0; is.read(reinterpret_cast(&token_length), sizeof(token_length)); char* tmp = new char[token_length]; is.read(tmp, token_length); // NOLINT std::string token(tmp, tmp + token_length); delete[] tmp; // read the token_id - int32_t token_id; + int32_t token_id = 0; is.read(reinterpret_cast(&token_id), sizeof(token_id)); data->emplace(token, token_id); diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 16a9065c2eb875..d7cfb4738822af 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -274,7 +274,7 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, TENSOR* dst) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - const platform::DeviceContext* dev_ctx; + const platform::DeviceContext* dev_ctx = nullptr; if (platform::is_gpu_place(dst_place) || platform::is_custom_place(dst_place)) { dev_ctx = pool.Get(dst_place); @@ -585,7 +585,7 @@ void TensorFromStream(std::istream& is, const platform::DeviceContext& dev_ctx, const size_t& seek, const std::vector& shape) { - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ( @@ -598,7 +598,7 @@ void TensorFromStream(std::istream& is, proto::VarType::TensorDesc desc; { // int32_t size // proto buffer - int32_t size; + int32_t size = 0; is.read(reinterpret_cast(&size), sizeof(size)); std::unique_ptr buf(new char[size]); // NOLINT is.read(reinterpret_cast(buf.get()), size); @@ -612,7 +612,7 @@ void TensorFromStream(std::istream& is, size_t seekg = seek * framework::SizeOfType(desc.data_type()); is.seekg(seekg, is.cur); // NOLINT - void* buf; + void* buf = nullptr; phi::CPUContext ctx; size_t size = tensor->numel() * framework::SizeOfType(desc.data_type()); if (platform::is_gpu_place(dev_ctx.GetPlace()) || @@ -652,7 +652,7 @@ void TensorFromStream(std::istream& is, void TensorFromStream(std::istream& is, phi::DenseTensor* tensor, const platform::DeviceContext& dev_ctx) { - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); PADDLE_ENFORCE_EQ( version, @@ -685,7 +685,7 @@ void TensorFromStream(std::istream& is, dims.reserve(static_cast(desc.dims().size())); std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); tensor->Resize(phi::make_ddim(dims)); - void* buf; + void* buf = nullptr; phi::CPUContext ctx; size_t size = tensor->numel() * framework::SizeOfType(desc.data_type()); if (platform::is_gpu_place(dev_ctx.GetPlace()) || diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 6f2b8844f52de7..cb872462a1297e 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -605,7 +605,7 @@ void BasicEngine::Execute() { } } catch (platform::EnforceNotMet& exception) { Clear(); - throw std::move(exception); + throw exception; } catch (std::exception& ex) { Clear(); PADDLE_THROW(platform::errors::External("%s", ex.what())); @@ -620,7 +620,7 @@ void BasicEngine::Execute() { } for (auto& pair : inplace_output_grad_var_list_) { - *pair.first = std::move(*pair.second); + *pair.first = *pair.second; } // Step 2: Sum Gradient of This graph diff --git a/paddle/fluid/imperative/data_loader.cc b/paddle/fluid/imperative/data_loader.cc index 3e2e96f1432773..bf09ac38d6d113 100644 --- a/paddle/fluid/imperative/data_loader.cc +++ b/paddle/fluid/imperative/data_loader.cc @@ -128,9 +128,9 @@ void SetLoadProcessSignalHandler() { } void ThrowErrorIfLoadProcessFailed() { - int error; - std::set *pids_set; - pid_t process_pid; + int error = 0; + std::set *pids_set = nullptr; + pid_t process_pid = 0; siginfo_t infop; for (auto &p : load_process_pids) { diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.cc b/paddle/fluid/imperative/jit/program_desc_tracer.cc index 757668f12ddc70..deda1ff572a704 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.cc +++ b/paddle/fluid/imperative/jit/program_desc_tracer.cc @@ -200,7 +200,7 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( } } - op_desc->SetInput(pair.first, std::move(names)); + op_desc->SetInput(pair.first, names); } for (auto &pair : op->Outputs()) { @@ -212,7 +212,7 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( } } - op_desc->SetOutput(pair.first, std::move(names)); + op_desc->SetOutput(pair.first, names); } } diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 3528fc87b6ab1f..3f8c35b6f5e556 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -77,7 +77,7 @@ static framework::RuntimeContext PrepareRuntimeContext( out_ctx.emplace_back(out_var->MutableVar()); } } - return framework::RuntimeContext(std::move(inputs), std::move(outputs)); + return framework::RuntimeContext(inputs, outputs); } template diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 833a13546ccd77..22651eaa1d9e0d 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -907,7 +907,7 @@ void PartialGradTask::RunEachOp(OpBase *op) { } else { for (auto &grad_var : input_pair.second) { if (grad_var) { - bool is_last; + bool is_last = false; new_inputs.emplace_back( ready_grad_vars_.Get(grad_var.get(), op->place(), &is_last)); VLOG(10) << "Got ready grad var " << grad_var->Name() << " " @@ -1031,7 +1031,7 @@ void PartialGradTask::RunEachOp(OpBase *op) { assign_op->SetPlace(op->place()); if (auto grad_pending_node = grad_grad->GetGradNode()) { - assign_node->InsertGradPendingNode(std::move(grad_pending_node)); + assign_node->InsertGradPendingNode(grad_pending_node); } } VLOG(10) << "Pending ops of assign is " diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 502eeb59114d0e..b03aadd4dc6aa2 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -451,7 +451,7 @@ void Reducer::InitializeGroups( .inside_group_index = inside_group_index++, }; } - group.variable_indices_ = std::move(variable_indices_); + group.variable_indices_ = variable_indices_; groups_.emplace_back(std::move(group)); // Debug Message For Reducer VLOG(3) << "The Group[" << group_index << "]:" << groups_.back(); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 6e188b3d21c642..0f992c9b8be309 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -348,7 +348,7 @@ void Tracer::TraceOpImpl(const std::string& type, } } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(type, &exception); - throw std::move(exception); + throw exception; } catch (std::exception& ex) { PADDLE_THROW( platform::errors::Fatal("Operator %s raises an %s exception.\n" diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 008be29fe94fb1..08bd2749ad3993 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -280,6 +280,9 @@ struct Argument { DECL_ARGUMENT_FIELD(tensorrt_optimization_level, TensorRtOptimizationLevel, int); + DECL_ARGUMENT_FIELD(tensorrt_ops_run_float, + TensorRtOpsRunFloat, + std::unordered_set); DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool); DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 9bc016fc62faf7..d3e4ce93ca01e5 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -226,6 +226,9 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("use_inspector", new bool(argument->tensorrt_use_inspector())); pass->Set("inspector_serialize", new bool(argument->tensorrt_inspector_serialize())); + pass->Set("trt_ops_run_float", + new std::unordered_set( + argument->tensorrt_ops_run_float())); pass->Set("use_explicit_quantization", new bool(argument->tensorrt_use_explicit_quantization())); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 7d0d43b8c8d23e..2e74062bedff62 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -379,6 +379,23 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( std::vector origin_outputs_dtype; std::map map_origin_outputs_dtype; + // rename output names in trt_ops_run_float + auto trt_ops_run_float = + Get>("trt_ops_run_float"); + for (auto node : subgraph) { + if (node->NodeType() == Node::Type::kOperation) { + for (auto *x : node->outputs) { + if (std::count(parameters.begin(), parameters.end(), x->Name()) > 0) + continue; + if (trt_ops_run_float.count(x->Name()) > 0) { + trt_ops_run_float.erase(x->Name()); + trt_ops_run_float.insert( + RenameVarBeUnique(x->Name(), std::to_string(x->id()))); + } + } + } + } + // Mark TensorRT output nodes as trt outputs auto mark_output = Get("mark_output"); auto output_tensor_name = @@ -393,7 +410,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( continue; if ((std::count(output_tensor_name.begin(), output_tensor_name.end(), - x->Name()) > 0) || + x->Name()) > 0) && !x->outputs.empty()) { VLOG(3) << "output " << x->Name() << " has been marked"; output_names.insert(x->Name()); @@ -783,6 +800,9 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( inference::Singleton::Global() .Create(engine_key + std::to_string(predictor_id), params); + // support force ops to run in FP32 precision + trt_engine->SetRunFloat(trt_ops_run_float); + if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( Get("model_opt_cache_dir"), engine_key); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index a0d66dc5092981..c3d4c3329016ad 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -482,6 +482,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(trt_engine_memory_sharing_); CP_MEMBER(trt_engine_memory_sharing_identifier_); CP_MEMBER(trt_optimization_level_); + CP_MEMBER(trt_ops_run_float_); // Dlnne related CP_MEMBER(use_dlnne_); CP_MEMBER(dlnne_min_subgraph_size_); @@ -606,7 +607,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { // deleted_pass. pass_builder_->ClearPasses(); auto other_passes = other.pass_builder()->AllPasses(); - for (auto pass : other_passes) { + for (auto const &pass : other_passes) { pass_builder_->AppendPass(pass); } } @@ -623,7 +624,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { other_passes.begin(), other_passes.end(), std::inserter(deleted_passes, deleted_passes.begin())); - for (auto ps : deleted_passes) { + for (auto const &ps : deleted_passes) { pass_builder_->DeletePass(ps); } } @@ -1148,7 +1149,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << xpu_config_.quant_post_static_gelu_out_threshold; ss << xpu_config_.quant_post_dynamic_activation_method; ss << xpu_config_.quant_post_dynamic_weight_precision; - for (auto type : xpu_config_.quant_post_dynamic_op_types) ss << type; + for (auto const &type : xpu_config_.quant_post_dynamic_op_types) ss << type; ss << xpu_lite_l3_locked_; ss << xpu_lite_enable_multi_stream_; @@ -1164,11 +1165,11 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << ipu_available_memory_proportion_; ss << ipu_enable_half_partial_; ss << ipu_enable_model_runtime_executor_; - for (auto custom_op : ipu_custom_ops_info_) - for (auto attr : custom_op) ss << attr; + for (auto const &custom_op : ipu_custom_ops_info_) + for (auto const &attr : custom_op) ss << attr; ss << ";"; - for (auto pattern : ipu_custom_patterns_) - for (auto attr : pattern) ss << attr; + for (auto const &pattern : ipu_custom_patterns_) + for (auto const &attr : pattern) ss << attr; ss << ";"; for (auto &op : mixed_black_list_) ss << op.c_str(); for (auto &op : mixed_white_list_) ss << op.c_str(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f30e2c560b57ff..a098bc524f2555 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -217,7 +217,7 @@ bool PaddleTensorToDenseTensor(const PaddleTensor &pt, phi::DenseTensor *t, const platform::Place &place) { framework::DDim ddim = phi::make_ddim(pt.shape); - void *input_ptr; + void *input_ptr = nullptr; if (pt.dtype == PaddleDType::INT64) { input_ptr = t->mutable_data(ddim, place); } else if (pt.dtype == PaddleDType::FLOAT32) { @@ -1147,6 +1147,9 @@ bool AnalysisPredictor::Run(const std::vector &inputs, bool AnalysisPredictor::Run(const std::vector &inputs, std::vector *outputs) { inference::DisplayMemoryInfo(place_, "before run"); + if (private_context_) { + paddle::platform::DeviceContextPool::SetDeviceContexts(&device_contexts_); + } paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads()); #ifdef PADDLE_WITH_DNNL if (config_.use_mkldnn_) MkldnnPreSet(inputs); @@ -1187,19 +1190,16 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } - // All the containers in the scope will be hold in inference, but the - // operators assume that the container will be reset after each batch. - // Here is a bugfix, collect all the container variables, and reset then to a - // bool; the next time, the operator will call MutableData and construct a new - // container again, so that the container will be empty for each batch. - if (sub_scope_) { - tensor_array_batch_cleaner_.CollectNoTensorVars(sub_scope_); - } - tensor_array_batch_cleaner_.ResetNoTensorVars(); + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(sub_scope_); + tensor_array_batch_cleaner_.ResetTensorArray(); // recover the cpu_math_library_num_threads to 1, in order to avoid thread // conflict when integrating it into deployment service. paddle::platform::SetNumThreads(1); + if (private_context_) { + paddle::platform::DeviceContextPool::SetDeviceContexts(nullptr); + } #ifdef PADDLE_WITH_DNNL if (config_.use_mkldnn_) MkldnnPostReset(); #endif @@ -1425,6 +1425,7 @@ void AnalysisPredictor::PrepareArgument() { config_.trt_use_explicit_quantization_); argument_->SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing()); argument_->SetTensorRtOptimizationLevel(config_.trt_optimization_level_); + argument_->SetTensorRtOpsRunFloat(config_.trt_ops_run_float_); } if (config_.dlnne_enabled()) { @@ -1468,7 +1469,7 @@ void AnalysisPredictor::PrepareArgument() { config_.NNAdapter().nnadapter_subgraph_partition_config_path); std::vector buffer_keys; std::vector> buffer_vals; - for (auto it : config_.NNAdapter().nnadapter_model_cache_buffers) { + for (auto const &it : config_.NNAdapter().nnadapter_model_cache_buffers) { buffer_keys.emplace_back(it.first); buffer_vals.emplace_back(it.second); } @@ -1884,7 +1885,7 @@ std::map> AnalysisPredictor::GetInputTensorShape() { std::map> input_shapes; std::vector names = GetInputNames(); - for (std::string name : names) { + for (std::string const &name : names) { auto *var = inference_program_->Block(0).FindVar(name); PADDLE_ENFORCE_NOT_NULL( var, @@ -1943,7 +1944,7 @@ std::map> AnalysisPredictor::GetOutputTensorShape() { std::map> output_shapes; std::vector names = GetOutputNames(); - for (std::string name : names) { + for (std::string const &name : names) { auto *var = inference_program_->Block(0).FindVar(name); PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( @@ -1988,7 +1989,7 @@ AnalysisPredictor::GetOutputTypes() { std::unique_ptr AnalysisPredictor::GetInputTensor( const std::string &name) { - framework::Scope *scope; + framework::Scope *scope = nullptr; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) if (config_.dist_config().use_dist_model()) { scope = scope_.get(); @@ -2039,7 +2040,7 @@ std::unique_ptr AnalysisPredictor::GetInputTensor( std::unique_ptr AnalysisPredictor::GetOutputTensor( const std::string &name) { - framework::Scope *scope; + framework::Scope *scope; // NOLINT #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) if (config_.dist_config().use_dist_model()) { scope = scope_.get(); @@ -2363,7 +2364,7 @@ void AnalysisPredictor::StatisticShapeRangeInfo() { decltype(min_data) max_data, decltype(min_data) opt_data, decltype(shape_info_) shape_data) { - for (auto it : shape_data) { + for (auto const &it : shape_data) { auto name = it.first; auto shapes = it.second; @@ -2954,6 +2955,7 @@ USE_TRT_CONVERTER(cumsum) USE_TRT_CONVERTER(assign) USE_TRT_CONVERTER(unbind) USE_TRT_CONVERTER(flip) +USE_TRT_CONVERTER(share_data) #if IS_TRT_VERSION_GE(8522) USE_TRT_CONVERTER(flash_multihead_matmul) USE_TRT_CONVERTER(cross_multihead_matmul) @@ -3221,6 +3223,13 @@ void InternalUtils::SetTransformerMaskid( #endif } +void InternalUtils::DisableTensorRtHalfOps( + paddle_infer::Config *c, const std::unordered_set &ops) { +#ifdef PADDLE_WITH_CUDA + c->trt_ops_run_float_ = ops; +#endif +} + void InternalUtils::SyncStream(paddle_infer::Predictor *p) { #ifdef PADDLE_WITH_CUDA auto *pred = dynamic_cast(p->predictor_.get()); diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 76b0410cc8e8f4..c3f50fd6f6bb39 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -214,7 +214,7 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, for (size_t i = 0; i < inputs.size(); ++i) { auto &input = feed_tensors_[i]; framework::DDim ddim = phi::make_ddim(inputs[i].shape); - void *input_ptr; + void *input_ptr = nullptr; if (inputs[i].dtype == PaddleDType::INT64) { input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::FLOAT32) { diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 0a0f786d9a04e5..1b604b544b9475 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -157,7 +157,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( // output of ops with unsigned input must be unsigned is_unsigned = true; double min_scale = std::numeric_limits::max(); - for (auto input_var_name : op->Input("X")) { + for (auto const& input_var_name : op->Input("X")) { PADDLE_ENFORCE_NE( scales_.find(input_var_name), scales_.end(), @@ -577,7 +577,7 @@ AnalysisPredictor::MkldnnQuantizer::Histogram( ++hist[bin]; } - return std::make_pair(std::move(hist), std::move(bin_width)); + return std::make_pair(std::move(hist), bin_width); } void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 5299fa4334ae83..4f9982f0a6d406 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -1294,6 +1294,8 @@ struct PD_INFER_DECL AnalysisConfig { bool trt_engine_memory_sharing_{false}; int trt_engine_memory_sharing_identifier_{0}; + std::unordered_set trt_ops_run_float_; + bool use_mkldnn_{false}; std::unordered_set mkldnn_enabled_op_types_; diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index af16aead74129e..3fefba9ef22be8 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include "crypto/cipher.h" @@ -517,6 +518,9 @@ class PD_INFER_DECL InternalUtils { static void SetTransformerMaskid( paddle_infer::Config* c, const std::string& tensorrt_transformer_maskid); + static void DisableTensorRtHalfOps( + paddle_infer::Config* c, const std::unordered_set& ops); + static void SyncStream(paddle_infer::Predictor* pred); static void SyncStream(cudaStream_t stream); template diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 2471c365e29ed9..206b2f5a6a2fdb 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -111,7 +111,8 @@ list( assign_op.cc flip_op.cc quantize_linear_op.cc - dequantize_linear_op.cc) + dequantize_linear_op.cc + share_data_op.cc) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index b68a703c7edf98..70893a97815943 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -43,45 +43,16 @@ void ConvertConv2d(TensorRTEngine* engine, framework::OpDesc op_desc(op, nullptr); auto* X = engine->GetITensor(op_desc.Input("Input").front()); - bool enable_int8 = op_desc.HasAttr("enable_int8"); - - if (enable_int8) { -#if IS_TRT_VERSION_GE(5000) - float in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale")); - engine->SetTensorDynamicRange(X, in_scale); -#endif - } - + std::string filter_var_name = op_desc.Input("Filter").front(); + auto* Y_v = scope.FindVar(filter_var_name); + phi::DenseTensor* Y_t = nullptr; + nvinfer1::ITensor* filter = nullptr; int n_output; int n_input; int filter_h; int filter_w; - std::string filter_var_name = op_desc.Input("Filter").front(); - TensorRTEngine::Weight weight; - if (engine->use_explicit_quantization()) { - auto* filter = engine->GetITensor(filter_var_name); - PADDLE_ENFORCE_NOT_NULL( - filter, - platform::errors::NotFound("Can not find %s ITensor in engine", - filter_var_name)); - auto filter_dims = filter->getDimensions(); - PADDLE_ENFORCE_EQ( - filter_dims.nbDims, - 4UL, - platform::errors::InvalidArgument( - "The conv2d filter's dims size should be 4, but got %d", - filter_dims.nbDims)); - n_output = filter_dims.d[0]; - n_input = filter_dims.d[1]; - filter_h = filter_dims.d[2]; - filter_w = filter_dims.d[3]; - } else { - auto* Y_v = scope.FindVar(filter_var_name); - PADDLE_ENFORCE_NOT_NULL( - Y_v, - platform::errors::NotFound("Can not find %s presistale var in scope.", - filter_var_name)); - auto* Y_t = Y_v->GetMutable(); + if (Y_v) { + Y_t = Y_v->GetMutable(); PADDLE_ENFORCE_EQ( Y_t->dims().size(), 4UL, @@ -92,7 +63,27 @@ void ConvertConv2d(TensorRTEngine* engine, n_input = Y_t->dims()[1]; filter_h = Y_t->dims()[2]; filter_w = Y_t->dims()[3]; - weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t); + } else { + filter = engine->GetITensor(op_desc.Input("Filter").front()); + PADDLE_ENFORCE_EQ( + filter->getDimensions().nbDims, + 4UL, + platform::errors::InvalidArgument( + "The conv2d filter's dims size should be 4, but got %d", + filter->getDimensions().nbDims)); + n_output = filter->getDimensions().d[0]; + n_input = filter->getDimensions().d[1]; + filter_h = filter->getDimensions().d[2]; + filter_w = filter->getDimensions().d[3]; + } + + bool enable_int8 = op_desc.HasAttr("enable_int8"); + + if (enable_int8) { +#if IS_TRT_VERSION_GE(5000) + float in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale")); + engine->SetTensorDynamicRange(X, in_scale); +#endif } const int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups")); const std::vector dilations = @@ -133,7 +124,10 @@ void ConvertConv2d(TensorRTEngine* engine, nv_post_paddings.d[0] = paddings[1]; nv_post_paddings.d[1] = paddings[3]; } - + TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, nullptr, 0); + if (Y_v) { + weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t); + } TensorRTEngine::Weight bias; bias.SetDataType(weight.get().type); bias.SetCount(0); @@ -167,7 +161,10 @@ void ConvertConv2d(TensorRTEngine* engine, layer->setStrideNd(nv_strides); layer->setPrePadding(nv_pre_paddings); - if (output_padding.size() > 0) { + + if (!Y_v) layer->setInput(1, *filter); + + if (!output_padding.empty()) { nv_post_paddings.d[0] -= output_padding[0]; nv_post_paddings.d[1] -= output_padding[1]; } @@ -186,11 +183,6 @@ void ConvertConv2d(TensorRTEngine* engine, // set dilations fset_dilation(layer, nv_dilations); - if (engine->use_explicit_quantization()) { - auto* filter_tensor = engine->GetITensor(op_desc.Input("Filter").front()); - layer->setInput(1, *filter_tensor); - } - auto output_name = op_desc.Output("Output").front(); layer->setName((name + " (Output: " + output_name + ")").c_str()); layer->getOutput(0)->setName(output_name.c_str()); @@ -206,6 +198,8 @@ class Conv2dOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + auto output_name = op_desc.Output("Output").front(); ConvertConv2d( engine_, op, @@ -223,6 +217,7 @@ class Conv2dOpConverter : public OpConverter { ksize, weight.get(), bias.get()); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); return layer; }, [](nvinfer1::IConvolutionLayer* layer, nvinfer1::DimsHW& dilations) { @@ -237,6 +232,8 @@ class Deconv2dOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + auto output_name = op_desc.Output("Output").front(); ConvertConv2d( engine_, op, @@ -254,6 +251,7 @@ class Deconv2dOpConverter : public OpConverter { ksize, weight.get(), bias.get()); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); return layer; }, [](nvinfer1::IDeconvolutionLayer* layer, nvinfer1::DimsHW& dilations) { diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 419383ff0a3342..198a164894c0b1 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -162,7 +162,6 @@ class ElementwiseTensorOpConverter : public OpConverter { *(less_layer->getOutput(0)), *(equal_layer->getOutput(0)), nvinfer1::ElementWiseOperation::kOR); - RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); } else if (op_type_ == "greater_equal") { auto* greater_layer = @@ -182,7 +181,6 @@ class ElementwiseTensorOpConverter : public OpConverter { *(greater_layer->getOutput(0)), *(equal_layer->getOutput(0)), nvinfer1::ElementWiseOperation::kOR); - RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); } else if (op_type_ == "mod") { auto* div_layer = @@ -191,17 +189,20 @@ class ElementwiseTensorOpConverter : public OpConverter { *X, *reshape_y_tensor, nvinfer1::ElementWiseOperation::kFLOOR_DIV); + SupportFP32MixPrecision(output_name, op_desc.Type(), div_layer); auto* mul_layer = TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *(div_layer->getOutput(0)), *reshape_y_tensor, nvinfer1::ElementWiseOperation::kPROD); + SupportFP32MixPrecision(output_name, op_desc.Type(), mul_layer); auto* layer = TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *X, *(mul_layer->getOutput(0)), nvinfer1::ElementWiseOperation::kSUB); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); } else { auto op_pair = ops.find(op_type_); @@ -215,6 +216,7 @@ class ElementwiseTensorOpConverter : public OpConverter { auto* layer = TRT_ENGINE_ADD_LAYER( engine_, ElementWise, *X, *reshape_y_tensor, op_pair->second); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); } } @@ -347,6 +349,7 @@ class PowOpConverter : public OpConverter { auto* layer = TRT_ENGINE_ADD_LAYER( engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kPOW); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/convert/matrix_multiply_op.cc b/paddle/fluid/inference/tensorrt/convert/matrix_multiply_op.cc index d985c6232c093e..ebe4c724180d13 100644 --- a/paddle/fluid/inference/tensorrt/convert/matrix_multiply_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/matrix_multiply_op.cc @@ -237,7 +237,7 @@ class MatrixMultiplyOpConverter : public OpConverter { matrix_operation_x, *input2, matrix_operation_y); - + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); if (enable_int8) { if (op_desc.HasAttr("out_threshold") || op_desc.HasAttr("Out")) { engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); @@ -259,6 +259,7 @@ class MatrixMultiplyOpConverter : public OpConverter { *layer->getOutput(0), *reshape_alpha->getOutput(0), nvinfer1::ElementWiseOperation::kPROD); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); } RreplenishLayerAndOutput( layer, "matrix_multiply_op", {output_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 429bc89f0d90ea..3eb01c0951e275 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -371,6 +371,23 @@ class OpConverter { engine->ClearWeights(); } + void SupportFP32MixPrecision(const std::string& output_name, + const std::string& op_type, + nvinfer1::ILayer* layer) { + if (engine_->OpIsRunFloat(output_name) || engine_->OpIsRunFloat(op_type)) { +#if IS_TRT_VERSION_GE(8210) + VLOG(3) << op_type << "(output: " << output_name << ")" + << " is forced to run in FP32 precision."; + layer->resetPrecision(); + layer->setPrecision(nvinfer1::DataType::kFLOAT); +#else + VLOG(3) + << op_type << "(output: " << output_name << ")" + << ": Set layer precision needs TensorRT version 8.2.1 and after."; +#endif + } + } + nvinfer1::ITensor* Cast(nvinfer1::ITensor* input, nvinfer1::DataType dtype) { auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input); layer->setOutputType(0, dtype); diff --git a/paddle/fluid/inference/tensorrt/convert/share_data_op.cc b/paddle/fluid/inference/tensorrt/convert/share_data_op.cc new file mode 100644 index 00000000000000..644eeda8d102f1 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/share_data_op.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class ShareDataOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a share_data op to tensorrt"; + framework::OpDesc op_desc(op, nullptr); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "share_data", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(share_data, ShareDataOpConverter); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index ef9989c9fc9ba0..9fe7b51391153c 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -370,6 +370,13 @@ void TensorRTEngine::FreezeNetwork() { params_.optimization_level); #endif +#if IS_TRT_VERSION_GE(8210) + if (!trt_ops_run_float_.empty()) { + infer_builder_config_->setFlag( + nvinfer1::BuilderFlag::kPREFER_PRECISION_CONSTRAINTS); + } +#endif + #if IS_TRT_VERSION_LT(8000) infer_engine_.reset(infer_builder_->buildEngineWithConfig( *network(), *infer_builder_config_)); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index d32666e8ccb5c5..ff35be1c607c7f 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -319,6 +319,14 @@ class TensorRTEngine { return quant_dynamic_range_.count(tensor); } + void SetRunFloat(const std::unordered_set& ops) { + trt_ops_run_float_ = ops; + } + + bool OpIsRunFloat(const std::string& op) const { + return trt_ops_run_float_.count(op) > 0; + } + // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. // so we need to copy the weights from GPU to CPU in our op converter. @@ -593,6 +601,9 @@ class TensorRTEngine { // Used for convert weight into Itensor const framework::Scope* scope_{nullptr}; + // specify run on float to avoid overflow + std::unordered_set trt_ops_run_float_; + #if IS_TRT_VERSION_GE(6000) int binding_num_; infer_ptr infer_builder_config_; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b44c58379ca732..b9c1ee5bdd8a69 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -324,12 +324,12 @@ struct SimpleOpTypeSetTeller : public Teller { auto* block = desc.Block(); if (block) { auto* filter_var_desc = block->FindVar(desc.Input("Filter")[0]); - if (!filter_var_desc->Persistable() && !use_explicit_quantization) { + if (!filter_var_desc->Persistable()) { #if IS_TRT_VERSION_GE(8600) #else LOG(INFO) << "Trt below 8.6 not support conv2d's filter is a intermedoate " - "tensor in conv2d op, please upgarde your TenroRT."; + "tensor in conv2d op, please upgarde your TensorRT."; return false; #endif } @@ -2918,7 +2918,8 @@ struct SimpleOpTypeSetTeller : public Teller { "assign", "flip", "quantize_linear", - "dequantize_linear"}; + "dequantize_linear", + "share_data"}; std::unordered_set teller_set{ "matrix_multiply", @@ -3086,7 +3087,8 @@ struct SimpleOpTypeSetTeller : public Teller { "assign", "flip", "quantize_linear", - "dequantize_linear"}; + "dequantize_linear", + "share_data"}; }; struct GenericPluginTeller : public Teller { diff --git a/paddle/fluid/inference/utils/io_utils.cc b/paddle/fluid/inference/utils/io_utils.cc index 0ee80e3700b5c9..27de396f597856 100644 --- a/paddle/fluid/inference/utils/io_utils.cc +++ b/paddle/fluid/inference/utils/io_utils.cc @@ -80,21 +80,21 @@ void SerializePDTensorToStream(std::ostream *os, const PaddleTensor &tensor) { void DeserializePDTensorToStream(std::istream &is, PaddleTensor *tensor) { // 1. Version - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); // 2. Name - uint64_t name_bytes; + uint64_t name_bytes = 0; is.read(reinterpret_cast(&name_bytes), sizeof(name_bytes)); std::vector bytes(name_bytes); is.read(bytes.data(), name_bytes); // NOLINT tensor->name = std::string(bytes.data(), name_bytes); // 3. LoD - uint64_t lod_level; + uint64_t lod_level = 0; is.read(reinterpret_cast(&lod_level), sizeof(lod_level)); auto *lod = &(tensor->lod); lod->resize(lod_level); for (uint64_t i = 0; i < lod_level; ++i) { - uint64_t size; + uint64_t size = 0; is.read(reinterpret_cast(&size), sizeof(size)); std::vector tmp(size / sizeof(size_t)); is.read(reinterpret_cast(tmp.data()), @@ -102,13 +102,13 @@ void DeserializePDTensorToStream(std::istream &is, PaddleTensor *tensor) { (*lod)[i] = tmp; } // 4. Shape - size_t dims; + size_t dims = 0; is.read(reinterpret_cast(&dims), sizeof(dims)); tensor->shape.resize(dims); is.read(reinterpret_cast(tensor->shape.data()), sizeof(int) * dims); // NOLINT // 5. Data - uint64_t length; + uint64_t length = 0; is.read(reinterpret_cast(&tensor->dtype), sizeof(tensor->dtype)); is.read(reinterpret_cast(&length), sizeof(length)); tensor->data.Resize(length); @@ -139,10 +139,10 @@ void SerializePDTensorsToStream(std::ostream *os, void DeserializePDTensorsToStream(std::istream &is, std::vector *tensors) { // 1. Version - uint32_t version; + uint32_t version = 0; is.read(reinterpret_cast(&version), sizeof(version)); // 2. Tensors - uint64_t num; + uint64_t num = 0; is.read(reinterpret_cast(&num), sizeof(num)); tensors->resize(num); for (auto &tensor : *tensors) { @@ -240,35 +240,41 @@ void DeserializeShapeRangeInfo( continue; } else { std::vector tmp(info.min_shape_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.min_shape(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.min_shape(static_cast(k)); min_shape->insert(std::make_pair(name, tmp)); tmp.resize(info.max_shape_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.max_shape(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.max_shape(static_cast(k)); max_shape->insert(std::make_pair(name, tmp)); tmp.resize(info.opt_shape_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.opt_shape(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.opt_shape(static_cast(k)); opt_shape->insert(std::make_pair(name, tmp)); } } for (int i = 0; i < shape_range_infos.shape_range_info_size(); ++i) { - auto info = shape_range_infos.shape_range_info(i); + auto info = shape_range_infos.shape_range_info(static_cast(i)); auto name = info.name(); if (min_value->count(name) || max_value->count(name) || opt_value->count(name)) { continue; } else { std::vector tmp(info.min_value_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.min_value(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.min_value(static_cast(k)); min_value->insert(std::make_pair(name, tmp)); tmp.resize(info.max_value_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.max_value(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.max_value(static_cast(k)); max_value->insert(std::make_pair(name, tmp)); tmp.resize(info.opt_value_size()); - for (size_t k = 0; k < tmp.size(); ++k) tmp[k] = info.opt_value(k); + for (size_t k = 0; k < tmp.size(); ++k) + tmp[k] = info.opt_value(static_cast(k)); opt_value->insert(std::make_pair(name, tmp)); } } diff --git a/paddle/fluid/inference/utils/table_printer.cc b/paddle/fluid/inference/utils/table_printer.cc index 7f192152e052f8..564757b88d69a8 100644 --- a/paddle/fluid/inference/utils/table_printer.cc +++ b/paddle/fluid/inference/utils/table_printer.cc @@ -101,7 +101,8 @@ void TablePrinter::InsertRow(const std::vector& row) { if (line.length() > max_width) max_width = line.length(); } - if (max_width > widths_[i]) widths_[i] = static_cast(max_width); + if (static_cast(max_width) > widths_[i]) + widths_[i] = static_cast(max_width); size_t num_lines = table_row[i].size(); if (num_lines > max_height) max_height = num_lines; @@ -159,13 +160,15 @@ void TablePrinter::CalcLayout() { // If the number of rows required for this record is larger than 1, we // will break that line and put it in multiple lines if (num_rows > 1) { - data_[i][j].erase(data_[i][j].begin() + line_index); + data_[i][j].erase(data_[i][j].begin() + line_index); // NOLINT for (size_t k = 0; k < num_rows; ++k) { size_t start = - std::min(static_cast(k * shares_[j]), line.length()); - size_t end = std::min(static_cast((k + 1) * shares_[j]), - line.length()); - data_[i][j].insert(data_[i][j].begin() + line_index + k, + std::min(static_cast(k * shares_[j]), // NOLINT + line.length()); + size_t end = + std::min(static_cast((k + 1) * shares_[j]), // NOLINT + line.length()); + data_[i][j].insert(data_[i][j].begin() + line_index + k, // NOLINT line.substr(start, end - start)); } @@ -173,8 +176,8 @@ void TablePrinter::CalcLayout() { line_index += num_rows - 1; } - if (heights_[i] < (num_rows - 1 + data_[i][j].size())) - heights_[i] += num_rows - 1; + if (heights_[i] < static_cast(num_rows - 1 + data_[i][j].size())) + heights_[i] += static_cast(num_rows - 1); } } } @@ -182,8 +185,8 @@ void TablePrinter::CalcLayout() { void TablePrinter::AddRowDivider(std::stringstream& ss) { ss << "+"; - for (auto share : shares_) { - for (size_t j = 0; j < share + 2; ++j) ss << "-"; + for (float share : shares_) { + for (float j = 0; j < share + 2; ++j) ss << "-"; ss << "+"; } ss << "\n"; @@ -191,15 +194,16 @@ void TablePrinter::AddRowDivider(std::stringstream& ss) { void TablePrinter::AddRow(std::stringstream& ss, size_t row_idx) { auto row = data_[row_idx]; - size_t max_height = heights_[row_idx]; + size_t max_height = static_cast(heights_[row_idx]); for (size_t h = 0; h < max_height; ++h) { ss << "|" << std::left; for (size_t i = 0; i < row.size(); ++i) { if (h < row[i].size()) { - ss << " " << std::setw(shares_[i]) << row[i][h] << " |"; + ss << " " << std::setw(static_cast(shares_[i])) << row[i][h] + << " |"; } else { - ss << " " << std::setw(shares_[i]) << " " + ss << " " << std::setw(static_cast(shares_[i])) << " " << " |"; } } diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 65ed57ebc9be15..a2910ed51b6751 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -259,6 +259,9 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, continue; } VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]); + IR_ENFORCE(var != nullptr, + "Can't find var recursively from current block."); + if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR) { need_inputs_sig.emplace_back("dense"); } else if (var->GetType() == @@ -280,7 +283,7 @@ pir::OpInfo OpTranscriber::LoopkUpOpInfo(pir::IrContext* ctx, if (need_inputs_sig.size() != sig.inputs.size()) { continue; } - size_t i; + size_t i = 0; for (i = 0; i < need_inputs_sig.size(); ++i) { if (need_inputs_sig[i] == "") { continue; @@ -677,10 +680,12 @@ void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx, pir::OpResult value = operation->result(idx_in_op); bool generated_by_vector = value.type().isa(); - (*param_map)[arg_name] = VariableDefiningInfo( - value, - generated_by_vector, - static_cast(generated_by_vector ? idx_in_vec : -1)); + param_map->PushValue( + arg_name, + VariableDefiningInfo( + value, + generated_by_vector, + static_cast(generated_by_vector ? idx_in_vec : -1))); } } @@ -816,7 +821,7 @@ struct AssignValueOpTranscriber : public OpTranscriber { std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) = op_info_concept->get_op_info_(); std::unordered_map attr_info_maps; - for (auto info : attr_infos) { + for (auto const& info : attr_infos) { attr_info_maps.insert({info.name, info}); } @@ -1171,7 +1176,7 @@ struct ShadowOutputOpTranscriber : public OpTranscriber { TranslationContext* param_map, const OpDesc& op_desc, pir::Block* block) override { - auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); std::vector op_inputs; auto legacy_input_vars = op_desc.Input("x", true); @@ -1186,7 +1191,7 @@ struct ShadowOutputOpTranscriber : public OpTranscriber { op_inputs.push_back(defining_info.value); pir::AttributeMap attribute_map = { - {"parameter_name", + {"output_name", pir::StrAttribute::get(ctx, op_desc.GetAttrIfExists("name"))}, }; @@ -1281,7 +1286,7 @@ struct FillConstant2FullTranscriber : public OpTranscriber { {"dtype", paddle::dialect::DataTypeAttribute::get( ctx, - paddle::dialect::VarTypeToDataType( + paddle::translator::VarTypeToDataType( static_cast(dtype)))}}; int place_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place_type")); @@ -1388,7 +1393,7 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { {"dtype", paddle::dialect::DataTypeAttribute::get( ctx, - paddle::dialect::VarTypeToDataType( + paddle::translator::VarTypeToDataType( static_cast(dtype)))}}; return attribute_map; } @@ -1433,11 +1438,11 @@ pir::OpResult TranslateNumClassesForOneHot( auto var_name = legacy_vars[0]; IR_ENFORCE(legacy_vars.size() == 1, "depth_tensor input of one hot MUST be a tensor"); - auto defining_info = param_map->find(legacy_vars[0]); - IR_ENFORCE(defining_info != param_map->end(), + IR_ENFORCE(param_map->count(legacy_vars[0]), "%s should be existed in one_hot_v2 as input depth_tensor.", legacy_vars[0]); - return defining_info->second.value; + auto defining_info = param_map->at(legacy_vars[0]); + return defining_info.value.dyn_cast(); } auto& attribute_translator = AttributeTranslator::instance(); @@ -1527,7 +1532,7 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, block, x_defining_info, x_name); x_defining_info = param_map->at(x_name); } - pir::OpResult x_value = x_defining_info.value; + pir::OpResult x_value = x_defining_info.value.dyn_cast(); IR_ENFORCE(x_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1558,7 +1563,7 @@ struct ElementwiseTranscriber : public OpTranscriber { ctx, param_map, block, y_defining_info, y_name); y_defining_info = param_map->at(y_name); } - pir::OpResult y_value = y_defining_info.value; + pir::OpResult y_value = y_defining_info.value.dyn_cast(); IR_ENFORCE(y_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1577,8 +1582,7 @@ struct ElementwiseTranscriber : public OpTranscriber { axis += static_cast(x_shape.size()); } - int append_size = - static_cast(x_shape.size() - axis - 1 - y_shape.size()); + int append_size = static_cast(x_shape.size() - axis - y_shape.size()); if (append_size < 0) { // which means x.rank <= y.rank, mostly // x.rank=y.rank return {x_value, y_value}; @@ -1593,7 +1597,7 @@ struct ElementwiseTranscriber : public OpTranscriber { pir::OpResult y_new; if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) { std::vector y_new_shape(y_shape); - for (int i = 0; i <= append_size; i++) { + for (int i = 0; i < append_size; i++) { y_new_shape.push_back(1); } dialect::ReshapeOp reshape_op = @@ -1605,7 +1609,7 @@ struct ElementwiseTranscriber : public OpTranscriber { auto shape_op = builder.Build(y_value); auto append_shape_op = builder.Build( std::vector(append_size, 1), - phi::DataType::INT64, + phi::DataType::INT32, phi::CPUPlace()); auto y_true_shape_op = builder.Build( std::vector{shape_op.out(), append_shape_op.out()}); @@ -1622,7 +1626,10 @@ struct ElementwiseTranscriber : public OpTranscriber { struct GradAddOpTranscriber : public ElementwiseTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { - const std::string& target_op_name = "pd_op.add"; + std::string target_op_name = "pd_op.add"; + if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { + target_op_name += "_"; + } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW( @@ -1675,7 +1682,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber { op_desc.Type(), y_name); auto y_defining_info = param_map->at(y_name); - pir::OpResult y_value = y_defining_info.value; + pir::OpResult y_value = y_defining_info.value.dyn_cast(); IR_ENFORCE(y_value, "Expected op[%s]'s input %s is not null", op_desc.Type(), @@ -1693,8 +1700,8 @@ struct ElementwiseGradTranscriber : public OpTranscriber { pir::OpResult value = operation->result(idx_in_op); pir::Builder builder(ctx, operation->GetParent()); auto reshape_op = builder.Build(value, y_shape); - (*param_map)[y_grad_var_name] = - VariableDefiningInfo(reshape_op.out(), false, -1); + param_map->PushValue(y_grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); } }; @@ -1766,7 +1773,7 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber { ctx, param_map, block, defining_info, var_name); defining_info = param_map->at(var_name).value; } - return defining_info.value; + return defining_info.value.dyn_cast(); }; } }; @@ -1861,9 +1868,24 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber { auto output_var = output_vars[0]; auto fused_feedforward_op = operation->dyn_cast(); - (*param_map)[output_var] = - VariableDefiningInfo{fused_feedforward_op.out()}; + param_map->PushValue(output_var, + VariableDefiningInfo{fused_feedforward_op.out()}); + } + } +}; + +struct ShareBufferOpTranscriber : public OpTranscriber { + pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, + const OpDesc& op_desc) override { + std::string target_op_name = dialect::ShareDataOp::name(); + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW( + "Op share_buffer should have corresponding OpInfo " + "pd_op.share_data"); } + + return op_info; } }; @@ -1890,6 +1912,7 @@ OpTranslator::OpTranslator() { special_handlers["reduce_any"] = ReduceOpTranscriber(); special_handlers["rnn"] = RnnOpTranscriber(); special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); + special_handlers["share_buffer"] = ShareBufferOpTranscriber(); special_handlers["set_value"] = LegacySetValueDispatcher(); special_handlers["set_value_grad"] = SetValueGradOpTranscriber(); special_handlers["split"] = SplitOpTranscriber(); diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 313a78da1aab95..11c2743117586b 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -44,6 +44,10 @@ using ProgramDesc = ::paddle::framework::ProgramDesc; using BlockDesc = ::paddle::framework::BlockDesc; using VarDesc = ::paddle::framework::VarDesc; +using TCKey = TranslationContext::Key; +using TCValue = TranslationContext::Value; +using TCContainer = TranslationContext::Container; + const std::unordered_set ProgramTranslator::no_cast_var_names = { "feed", "fetch", @@ -51,25 +55,42 @@ const std::unordered_set ProgramTranslator::no_cast_var_names = { const std::unordered_set ProgramTranslator::unsupported_ops = { "conditional_block_grad", - "while", "while_grad", }; static std::vector GetCondOpIds(const BlockDesc& src_block, uint64_t first_id) { std::vector op_list = {first_id}; - if (src_block.Op(first_id + 1)->Type() == "logical_not") { + if (((first_id + 1) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 1))->Type() == "logical_not")) { op_list.emplace_back(first_id + 1); } - if (src_block.Op(first_id + 2)->Type() == "conditional_block") { + if (((first_id + 2) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 2))->Type() == + "conditional_block")) { op_list.emplace_back(first_id + 2); } - if (src_block.Op(first_id + 3)->Type() == "cast") { + if (((first_id + 3) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 3))->Type() == "cast")) { op_list.emplace_back(first_id + 3); } - size_t output_size = src_block.Op(first_id)->Output("Out").size(); + // Note(zhangbo): Some output variables are input, without select_input op. + std::vector output_names = + src_block.Op(static_cast(first_id))->Output("Out"); + std::vector input_names = + src_block.Op(static_cast(first_id))->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + size_t output_size = diffs.size(); for (size_t i = 0; i < output_size; i++) { - if (src_block.Op(first_id + 4 + i)->Type() == "select_input") { + if (((first_id + 4 + i) < src_block.OpSize()) && + (src_block.Op(static_cast(first_id + 4 + i))->Type() == + "select_input")) { op_list.emplace_back(first_id + 4 + i); } } @@ -80,7 +101,7 @@ ConditionBlockCombination::ConditionBlockCombination( const ::paddle::framework::BlockDesc& src_block, const std::vector& op_ids) { for (auto op_id : op_ids) { - op_list_.emplace_back(src_block.Op(op_id)); + op_list_.emplace_back(src_block.Op(static_cast(op_id))); } PADDLE_ENFORCE(Verify(op_list_), platform::errors::NotFound( @@ -94,7 +115,16 @@ const std::string& ConditionBlockCombination::CondVarName() const { } size_t ConditionBlockCombination::OutputSize() const { - return op_list_[0]->Output("Out").size(); + std::vector output_names = op_list_[0]->Output("Out"); + std::vector input_names = op_list_[0]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs.size(); } std::vector<::paddle::framework::VarDesc*> @@ -109,23 +139,41 @@ ConditionBlockCombination::OutputVars() const { return outputs; } -const std::vector& -ConditionBlockCombination::TrueBlockOutputVarNames() const { - return op_list_[0]->Output("Out"); -} - -int ConditionBlockCombination::TrueBlockId() const { - return op_list_[0]->GetBlockAttrId("sub_block"); +std::vector ConditionBlockCombination::TrueBlockOutputVarNames() + const { + std::vector output_names = op_list_[0]->Output("Out"); + std::vector input_names = op_list_[0]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs; } std::vector ConditionBlockCombination::FalseBlockOutputVarNames() const { if (op_list_.size() > 1) { - return op_list_[2]->Output("Out"); + std::vector output_names = op_list_[2]->Output("Out"); + std::vector input_names = op_list_[2]->Input("Input"); + std::vector diffs(output_names.size()); + auto iter = std::set_difference(output_names.begin(), + output_names.end(), + input_names.begin(), + input_names.end(), + diffs.begin()); + diffs.resize(iter - diffs.begin()); + return diffs; } return {""}; } +int ConditionBlockCombination::TrueBlockId() const { + return op_list_[0]->GetBlockAttrId("sub_block"); +} + int ConditionBlockCombination::FalseBlockId() const { if (op_list_.size() > 1) { return op_list_[2]->GetBlockAttrId("sub_block"); @@ -140,9 +188,6 @@ bool ConditionBlockCombination::Verify( if (op_list[id]->Type() != "conditional_block") { return false; } - if (op_list.size() == 1 && op_list[id]->Output("Out").size() != 0) { - return false; - } } else if (id == 1) { if (op_list[id]->Type() != "logical_not") { return false; @@ -176,6 +221,55 @@ bool ConditionBlockCombination::Verify( return true; } +const TCValue& TranslationContext::operator[](const TCKey& key) const { + return at(key); +} + +const TCValue& TranslationContext::at(const TCKey& key) const { + auto it = container_.find(key); + if (it == container_.end() && parent_) { + return parent_->at(key); + } + PADDLE_ENFORCE_NE(it, + container_.end(), + platform::errors::InvalidArgument( + "param %s should exists in TranslationContext", key)); + const auto& values = it->second; + PADDLE_ENFORCE_NE( + values.size(), + 0, + platform::errors::InvalidArgument( + "param %s should have size > 0, but get:%d", key, values.size())); + return values.back(); +} + +size_t TranslationContext::count(const TCKey& key) const { + auto it = container_.find(key); + if (it == container_.end()) { + if (parent_) return parent_->count(key); + return 0u; + } + const auto& values = it->second; + PADDLE_ENFORCE_NE( + values.size(), + 0u, + platform::errors::InvalidArgument( + "param %s should have size > 0, but get:%d", key, values.size())); + return values.size(); +} + +void TranslationContext::PushValue(const Key& key, const Value& value) { + container_[key].push_back(value); +} +void TranslationContext::PopValue(const Key& key) { + container_[key].pop_back(); +} + +TranslationContext* TranslationContext::CreateInnerContext() { + sons_.emplace_back(std::make_unique(this)); + return sons_.back().get(); +} + ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, pir::Program* program) : legacy_program_(legacy_program), program_(program) { @@ -189,6 +283,7 @@ void ProgramTranslator::Translate() { TranslateBlock(legacy_program_->Block(0), 0, legacy_program_->Block(0).OpSize(), + ¶m_map_, program_->block()); SetParameterFromSingleBlock(legacy_program_->Block(0)); @@ -204,11 +299,14 @@ void ProgramTranslator::Translate() { } } -void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, - uint64_t start_id, - uint64_t end_id, - pir::Block* dest_block, - bool for_cond_block) { +void ProgramTranslator::TranslateBlock( + const BlockDesc& src_block, + uint64_t start_id, + uint64_t end_id, + TranslationContext* translation_ctx, + pir::Block* dest_block, + bool for_cond_block, + std::vector skip_cond_assign) { VLOG(8) << "=============>start to translate a block"; PADDLE_ENFORCE( (src_block.OpSize() >= end_id) && (start_id <= end_id), @@ -220,11 +318,13 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, src_block.OpSize())); std::unordered_map translate_completed; + std::vector assign_inputs; for (uint64_t op_id = start_id; op_id < end_id; op_id++) { if (translate_completed.count(op_id) && translate_completed.at(op_id)) { continue; } - auto op = src_block.Op(op_id); + + auto op = src_block.Op(static_cast(op_id)); VLOG(8) << "=============>start to translate a op: " << op->Type(); PADDLE_ENFORCE_EQ(unsupported_ops.count(op->Type()), @@ -236,27 +336,33 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, std::vector cond_op_list = {op}; std::vector cond_op_ids = GetCondOpIds(src_block, op_id); ConditionBlockCombination cond_op_combination(src_block, cond_op_ids); - pir::Operation* if_op = - TranslateCondIfOperation(cond_op_combination, dest_block); + pir::Operation* if_op = TranslateCondIfOperation( + cond_op_combination, translation_ctx, dest_block); for (auto cond_id : cond_op_ids) { translate_completed[cond_id] = true; } VLOG(10) << "[op translated][conditional_block]" << if_op; + } else if (op->Type() == "while") { + TranslateWhileOperation(op, translation_ctx, dest_block); } else { - TranslateGeneralOperation(op, dest_block); - translate_completed[op_id] = true; + if (for_cond_block && op->Type() == "assign" && + std::count(skip_cond_assign.begin(), + skip_cond_assign.end(), + op->Output("Out")[0])) { + assign_inputs.push_back(op->Input("X")[0]); + translate_completed[op_id] = true; + } else { + TranslateGeneralOperation(op, translation_ctx, dest_block); + translate_completed[op_id] = true; + } } } // NOTE(zhangbo): If conditional_block operator has output, the cf.yeild // operator needs to be inserted if (for_cond_block) { std::vector yeild_inputs; - for (size_t id = end_id; id < src_block.OpSize(); id++) { - PADDLE_ENFORCE( - src_block.Op(id)->Type() == "assign", - "The operator at the end of the sub block needs to be assign"); - yeild_inputs.emplace_back( - param_map_[src_block.Op(id)->Input("X")[0]].value); + for (size_t id = 0; id < assign_inputs.size(); id++) { + yeild_inputs.emplace_back((*translation_ctx)[assign_inputs[id]].value); } pir::AttributeMap attribute_map; auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); @@ -267,11 +373,13 @@ void ProgramTranslator::TranslateBlock(const BlockDesc& src_block, } pir::Operation* ProgramTranslator::TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block) { + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dest_block) { auto& type_translator = TypeTranslator::instance(); auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::IfOp::name()); std::vector op_inputs = { - param_map_[cond_ops.CondVarName()].value}; + (*translation_ctx)[cond_ops.CondVarName()].value}; // NOTE(zhangbo): Now paddle::dialect::IfOp has 0 attribute pir::AttributeMap attribute_map; @@ -291,8 +399,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( op_inputs, attribute_map, op_output_types, op_info, 2); for (size_t i = 0; i < output_vardescs.size(); i++) { - param_map_[output_vardescs[i]->Name()] = - VariableDefiningInfo(operation->result(i)); + translation_ctx->PushValue(output_vardescs[i]->Name(), + VariableDefiningInfo(operation->result(i))); } dest_block->push_back(operation); @@ -303,11 +411,16 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.TrueBlockId()); pir::Region& true_region = operation->region(0); if (true_region.empty()) true_region.emplace_back(); + + auto* true_block_context = translation_ctx->CreateInnerContext(); + TranslateBlock(true_sub_block, 0, - true_sub_block.OpSize() - cond_ops.OutputSize(), + true_sub_block.OpSize(), + true_block_context, true_region.front(), - true); + true, + cond_ops.TrueBlockOutputVarNames()); } VLOG(4) << "[general op][conditional_block] IfOp true block translate end."; @@ -316,28 +429,105 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( legacy_program_->Block(cond_ops.FalseBlockId()); pir::Region& false_region = operation->region(1); if (false_region.empty()) false_region.emplace_back(); + auto* false_block_context = translation_ctx->CreateInnerContext(); TranslateBlock(false_sub_block, 0, - false_sub_block.OpSize() - cond_ops.OutputSize(), + false_sub_block.OpSize(), + false_block_context, false_region.front(), - true); + true, + cond_ops.FalseBlockOutputVarNames()); } VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; + + operation->Verify(); VLOG(4) << "[general op][conditional_block] IfOp translate end."; return operation; } -void ProgramTranslator::TranslateGeneralOperation(const OpDesc* src_op, - pir::Block* dest_block) { +void ProgramTranslator::TranslateWhileOperation( + const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dest_block) { + VLOG(8) << "=============>Start to translate while op:" << op; + auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); + int index = static_cast(sub_block.OpSize()) - 1; + std::vector> loop_vars_reverse; + while (index >= 0) { + auto sub_op = sub_block.Op(index); + if (sub_op->Type() == "assign" && + translation_ctx->count(sub_op->Output("Out")[0]) > 0) { + loop_vars_reverse.emplace_back(sub_op->Output("Out")[0], + sub_op->Input("X")[0]); + --index; + } else { + break; + } + } + PADDLE_ENFORCE(!loop_vars_reverse.empty(), + platform::errors::PreconditionNotMet( + "While op must has condition value input")); + PADDLE_ENFORCE(loop_vars_reverse.front().first == op->Input("Condition")[0], + platform::errors::PreconditionNotMet( + "The last op in sub_block of While op must used to assign " + "condition var")); + auto op_info = ctx_->GetRegisteredOpInfo(paddle::dialect::WhileOp::name()); + std::vector op_inputs{ + translation_ctx->at(loop_vars_reverse[0].first).value}; + std::vector op_outputs_type; + auto body_block = new pir::Block(); + std::vector param_status; + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].first; + auto& tc_value = translation_ctx->at(name); + auto val_type = tc_value.value.type(); + op_inputs.push_back(tc_value.value); + op_outputs_type.push_back(val_type); + param_status.emplace_back(tc_value); + translation_ctx->PushValue(name, body_block->AddArgument(val_type)); + } + pir::Operation* while_op = + pir::Operation::Create(op_inputs, {}, op_outputs_type, op_info, 1); + dest_block->push_back(while_op); + while_op->region(0).push_back(body_block); + TranslateBlock(sub_block, 0, index + 1, translation_ctx, body_block); + + auto yeild_info = ctx_->GetRegisteredOpInfo(pir::YieldOp::name()); + std::vector yeild_inputs{ + translation_ctx->at(loop_vars_reverse[0].second).value}; + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].second; + yeild_inputs.push_back(translation_ctx->at(name).value); + } + body_block->push_back( + pir::Operation::Create(yeild_inputs, {}, {}, yeild_info)); + + index = 0; + for (size_t idx = loop_vars_reverse.size() - 1u; idx > 0; --idx) { + auto& name = loop_vars_reverse[idx].first; + translation_ctx->PushValue(name, param_status[index++]); + } + auto name_iter = loop_vars_reverse.rbegin(); + for (size_t idx = 0; idx < while_op->num_results(); ++idx) { + translation_ctx->PushValue(name_iter++->first, while_op->result(idx)); + } + while_op->Verify(); + VLOG(8) << "=============>end to translate while op:" << op; +} + +void ProgramTranslator::TranslateGeneralOperation( + const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dest_block) { auto& op_translator = OpTranslator::instance(); OpTranslateFn& fn = op_translator[src_op->Type()]; if (src_op->Type() == "shadow_output") { - if (!param_map_.count(src_op->Input("x")[0])) { + if (!translation_ctx->count(src_op->Input("x")[0])) { return; } } - pir::Operation* operation = fn(ctx_, ¶m_map_, *src_op, dest_block); - VLOG(10) << "[op translated][special]" << operation << "end"; + pir::Operation* operation = fn(ctx_, translation_ctx, *src_op, dest_block); + VLOG(10) << "[op translated][general]" << operation << "end"; } inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx, @@ -356,7 +546,7 @@ inline pir::Operation* InsertGetParamaterOp(pir::IrContext* ctx, } inline pir::Operation* InsertSetParamaterOp(pir::IrContext* ctx, - pir::OpResult defining_op_result, + pir::Value defining_op_result, const VarDesc* var) { std::string set_parameter_op_name(pir::SetParameterOp::name()); pir::OpInfo op_info = ctx->GetRegisteredOpInfo(set_parameter_op_name); @@ -407,7 +597,7 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { "VarDesc of [%s] can not be nullptr", var_name)); pir::Operation* op = InsertGetParamaterOp(ctx_, var_desc); program_->block()->push_back(op); - param_map_[var_name] = VariableDefiningInfo(op->result(0)); + param_map_.PushValue(var_name, VariableDefiningInfo(op->result(0))); VLOG(10) << "[op translated][get parameter]" << var_name; program_->SetParameter(var_name, nullptr); @@ -425,20 +615,6 @@ void ProgramTranslator::GetParameterForSingleBlock(const BlockDesc& block) { } } -void ProgramTranslator::InsertOperationToSingleBlock(const BlockDesc& block) { - auto& op_translator = OpTranslator::instance(); - for (auto op : block.AllOps()) { - OpTranslateFn& fn = op_translator[op->Type()]; - if (op->Type() == "shadow_output") { - if (!param_map_.count(op->Input("x")[0])) { - continue; - } - } - pir::Operation* operation = fn(ctx_, ¶m_map_, *op, program_->block()); - VLOG(10) << "[op translated][special]" << operation; - } -} - void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { const auto& ops = block.AllOps(); for (auto op_desc = ops.rbegin(); op_desc != ops.rend(); op_desc++) { @@ -459,7 +635,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { need_set_parameter_op &= (param_map_.count(var_name) != 0); need_set_parameter_op &= (!set_input_var_names.count(var_name)); if (need_set_parameter_op) { - pir::OpResult defining_op_result = param_map_[var_name].value; + pir::OpResult defining_op_result = + param_map_[var_name].value.dyn_cast(); if (!defining_op_result) { continue; } @@ -470,7 +647,8 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { program_->block(), param_map_[var_name], var_name); - defining_op_result = param_map_.at(var_name).value; + defining_op_result = + param_map_.at(var_name).value.dyn_cast(); } pir::Operation* op = InsertSetParamaterOp( @@ -501,38 +679,37 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( const BlockDesc& block) { // Currently we set stop gradient for operation that generated a value // connected with VarDesc - for (const auto& [var_name, value_info] : param_map_) { + for (const auto& [var_name, value_list] : param_map_) { if (no_cast_var_names.count(var_name) != 0) continue; VLOG(10) << "[op translated][stop gradient]" << var_name; VarDesc* var = block.FindVarRecursive(var_name); if (var == nullptr) { continue; } - pir::OpResult value = value_info.value; - if (!value) { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "Value of [%s] can not ber None", var_name)); - } - auto* defining_op = value.owner(); - PADDLE_ENFORCE_NOT_NULL( - defining_op, - phi::errors::PreconditionNotMet( - "Defining operator of [%s] can not be nullptr", var_name)); - VLOG(8) << "[op translated][stop gradient]" << var_name - << " from: " << defining_op->name(); - std::vector stop_gradients; - if (defining_op->HasAttribute(kAttrStopGradients)) { - stop_gradients = defining_op->attribute(kAttrStopGradients) - .dyn_cast() - .AsVector(); - } else { - stop_gradients = std::vector( - defining_op->num_results(), pir::BoolAttribute::get(ctx_, false)); + for (const auto& value_info : value_list) { + pir::OpResult value = value_info.value.dyn_cast(); + if (!value) continue; + auto* defining_op = value.owner(); + PADDLE_ENFORCE_NOT_NULL( + defining_op, + phi::errors::PreconditionNotMet( + "Defining operator of [%s] can not be nullptr", var_name)); + VLOG(8) << "[op translated][stop gradient]" << var_name + << " from: " << defining_op->name(); + std::vector stop_gradients; + if (defining_op->HasAttribute(kAttrStopGradients)) { + stop_gradients = defining_op->attribute(kAttrStopGradients) + .dyn_cast() + .AsVector(); + } else { + stop_gradients = std::vector( + defining_op->num_results(), pir::BoolAttribute::get(ctx_, false)); + } + stop_gradients[value.index()] = + pir::BoolAttribute::get(ctx_, var->StopGradient()); + defining_op->set_attribute( + kAttrStopGradients, pir::ArrayAttribute::get(ctx_, stop_gradients)); } - stop_gradients[value.index()] = - pir::BoolAttribute::get(ctx_, var->StopGradient()); - defining_op->set_attribute(kAttrStopGradients, - pir::ArrayAttribute::get(ctx_, stop_gradients)); } } @@ -540,39 +717,49 @@ void ProgramTranslator::SetIsPersisableAttributeForAllValue( const BlockDesc& block) { // Currently we set is persisable for operation that generated a value // connected with VarDesc - for (const auto& [var_name, value_info] : param_map_) { + for (const auto& [var_name, value_list] : param_map_) { if (no_cast_var_names.count(var_name) != 0) continue; VLOG(10) << "[op translated][is persisable]" << var_name; VarDesc* var = block.FindVarRecursive(var_name); if (var == nullptr) { continue; } - pir::OpResult value = value_info.value; - if (!value) { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "Value of [%s] can not ber None", var_name)); + for (const auto& value_info : value_list) { + pir::OpResult value = value_info.value.dyn_cast(); + if (!value) continue; + auto* defining_op = value.owner(); + PADDLE_ENFORCE_NOT_NULL( + defining_op, + phi::errors::PreconditionNotMet( + "Defining operator of [%s] can not be nullptr", var_name)); + VLOG(8) << "[op translated][is persisable]" << var_name + << " from: " << defining_op->name(); + std::vector is_persisable; + if (defining_op->HasAttribute(kAttrIsPersisable)) { + is_persisable = defining_op->attribute(kAttrIsPersisable) + .dyn_cast() + .AsVector(); + } else { + is_persisable = std::vector( + defining_op->num_results(), pir::BoolAttribute::get(ctx_, false)); + } + is_persisable[value.index()] = + pir::BoolAttribute::get(ctx_, var->Persistable()); + defining_op->set_attribute(kAttrIsPersisable, + pir::ArrayAttribute::get(ctx_, is_persisable)); } - auto* defining_op = value.owner(); - PADDLE_ENFORCE_NOT_NULL( - defining_op, - phi::errors::PreconditionNotMet( - "Defining operator of [%s] can not be nullptr", var_name)); - VLOG(8) << "[op translated][is persisable]" << var_name - << " from: " << defining_op->name(); - std::vector is_persisable; - if (defining_op->HasAttribute(kAttrIsPersisable)) { - is_persisable = defining_op->attribute(kAttrIsPersisable) - .dyn_cast() - .AsVector(); - } else { - is_persisable = std::vector( - defining_op->num_results(), pir::BoolAttribute::get(ctx_, false)); + } +} + +std::unordered_map> +ProgramTranslator::VarDesc2Value() { + std::unordered_map> var_desc_2_value; + for (const auto& [var_name, value_info_list] : param_map_) { + for (const auto& value_info : value_info_list) { + var_desc_2_value[var_name].push_back(value_info.value); } - is_persisable[value.index()] = - pir::BoolAttribute::get(ctx_, var->Persistable()); - defining_op->set_attribute(kAttrIsPersisable, - pir::ArrayAttribute::get(ctx_, is_persisable)); } + return var_desc_2_value; } } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.h b/paddle/fluid/ir_adaptor/translator/program_translator.h index a59f4b34a5adaa..97c7ae1ec86879 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.h +++ b/paddle/fluid/ir_adaptor/translator/program_translator.h @@ -18,6 +18,7 @@ #include #include #include + #include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/program_desc.h" @@ -29,7 +30,7 @@ namespace paddle { namespace translator { struct VariableDefiningInfo { - VariableDefiningInfo(pir::OpResult value, + VariableDefiningInfo(pir::Value value, bool generated_by_vector = false, int idx_in_vector = -1) : value(value), @@ -37,7 +38,7 @@ struct VariableDefiningInfo { idx_in_vector(idx_in_vector) {} VariableDefiningInfo() {} - pir::OpResult value; + pir::Value value; bool generated_by_vector = false; // true if target variable is generated by Vector @@ -50,12 +51,12 @@ class ConditionBlockCombination { ConditionBlockCombination(const ::paddle::framework::BlockDesc& src_block, const std::vector& op_ids); const std::string& CondVarName() const; + int TrueBlockId() const; + int FalseBlockId() const; size_t OutputSize() const; std::vector<::paddle::framework::VarDesc*> OutputVars() const; - const std::vector& TrueBlockOutputVarNames() const; - int TrueBlockId() const; + std::vector TrueBlockOutputVarNames() const; std::vector FalseBlockOutputVarNames() const; - int FalseBlockId() const; private: bool Verify(const std::vector<::paddle::framework::OpDesc*>& op_list); @@ -63,8 +64,35 @@ class ConditionBlockCombination { std::vector<::paddle::framework::OpDesc*> op_list_; }; -using TranslationContext = - std::unordered_map; +class TranslationContext { + public: + using Key = std::string; + using Value = VariableDefiningInfo; + using ValueList = std::vector; + using Container = std::unordered_map; + + TranslationContext() {} + explicit TranslationContext(TranslationContext* parent) : parent_(parent) {} + ~TranslationContext() {} + + const Value& operator[](const Key& key) const; + const Value& at(const Key& key) const; + size_t count(const Key& key) + const; // Caution: not exactly same as count in stl library + + void PushValue(const Key& key, const Value& value); + void PopValue(const Key& key); + TranslationContext* CreateInnerContext(); + + Container::const_iterator begin() const { return container_.begin(); } + Container::const_iterator end() const { return container_.end(); } + + private: + Container container_; + TranslationContext* parent_ = nullptr; + std::vector> + sons_; // used to seperate different block +}; class ProgramTranslator { using ProgramDesc = ::paddle::framework::ProgramDesc; @@ -78,6 +106,8 @@ class ProgramTranslator { void Translate(); + std::unordered_map> VarDesc2Value(); + private: const ProgramDesc* legacy_program_; // not owned pir::Program* program_; // not owned @@ -100,18 +130,26 @@ class ProgramTranslator { void TranslateBlock(const BlockDesc& src_block, uint64_t start_id, uint64_t end_id, + TranslationContext* translation_ctx, pir::Block* dest_block, - bool for_cond_block = false); - void TranslateGeneralOperation(const OpDesc* src_op, pir::Block* dest_block); + bool for_cond_block = false, + std::vector skip_cond_assign = {}); + void TranslateGeneralOperation(const OpDesc* src_op, + TranslationContext* translation_ctx, + pir::Block* dest_block); void GetParameterForSingleBlock(const BlockDesc& block); - void InsertOperationToSingleBlock(const BlockDesc& block); void SetParameterFromSingleBlock(const BlockDesc& block); void SetStopGradientAttributeForAllValue(const BlockDesc& block); void SetIsPersisableAttributeForAllValue(const BlockDesc& block); /// Translate methods for control flow ops. pir::Operation* TranslateCondIfOperation( - const ConditionBlockCombination& cond_ops, pir::Block* dest_block); + const ConditionBlockCombination& cond_ops, + TranslationContext* translation_ctx, + pir::Block* dest_block); + void TranslateWhileOperation(const OpDesc* op, + TranslationContext* translation_ctx, + pir::Block* dest_block); }; } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/utils.cc b/paddle/fluid/ir_adaptor/translator/utils.cc index 5ee0c91b5bae5f..7f50115c5c578e 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.cc +++ b/paddle/fluid/ir_adaptor/translator/utils.cc @@ -59,7 +59,7 @@ pir::Operation* InsertSliceOperationForTarget( op_info); block->push_back(operation); pir::OpResult target_op_result = operation->result(0); - (*param_map)[arg_name] = VariableDefiningInfo(target_op_result); + param_map->PushValue(arg_name, VariableDefiningInfo(target_op_result)); return operation; } diff --git a/paddle/fluid/ir_adaptor/translator/utils.h b/paddle/fluid/ir_adaptor/translator/utils.h index 8745ee2ac0d7bf..a4765940d0a78a 100644 --- a/paddle/fluid/ir_adaptor/translator/utils.h +++ b/paddle/fluid/ir_adaptor/translator/utils.h @@ -17,6 +17,7 @@ #include #include +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/pir/core/ir_context.h" @@ -61,5 +62,43 @@ std::ostream& operator<<(std::ostream& os, std::vector CheckUnregisteredOperation( pir::IrContext* ctx, const framework::ProgramDesc& legacy_program); +inline DataType VarTypeToDataType( + ::paddle::framework::proto::VarType_Type var_type) { + switch (var_type) { + case paddle::framework::proto::VarType_Type::VarType_Type_BOOL: + return DataType::BOOL; + case paddle::framework::proto::VarType_Type::VarType_Type_INT16: + return DataType::INT16; + case paddle::framework::proto::VarType_Type::VarType_Type_INT32: + return DataType::INT32; + case paddle::framework::proto::VarType_Type::VarType_Type_INT64: + return DataType::INT64; + case paddle::framework::proto::VarType_Type::VarType_Type_FP16: + return DataType::FLOAT16; + case paddle::framework::proto::VarType_Type::VarType_Type_FP32: + return DataType::FLOAT32; + case paddle::framework::proto::VarType_Type::VarType_Type_FP64: + return DataType::FLOAT64; + case paddle::framework::proto::VarType_Type::VarType_Type_SIZE_T: + return DataType::UINT64; + case paddle::framework::proto::VarType_Type::VarType_Type_UINT8: + return DataType::UINT8; + case paddle::framework::proto::VarType_Type::VarType_Type_INT8: + return DataType::INT8; + case paddle::framework::proto::VarType_Type::VarType_Type_BF16: + return DataType::BFLOAT16; + case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX64: + return DataType::COMPLEX64; + case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX128: + return DataType::COMPLEX128; + case paddle::framework::proto::VarType_Type::VarType_Type_PSTRING: + return DataType::PSTRING; + default: + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported proto::VarType_Type `%s` when casting it into DataType.", + var_type)); + } +} + } // namespace translator } // namespace paddle diff --git a/paddle/fluid/jit/property.cc b/paddle/fluid/jit/property.cc index 174b3b065f1fac..9b0c50a954624c 100644 --- a/paddle/fluid/jit/property.cc +++ b/paddle/fluid/jit/property.cc @@ -340,7 +340,7 @@ void Property::SetStrings(const std::vector &v) { auto type = proto::ValueProto::STRINGS; auto entry = property_.add_entrys(); entry->set_type(type); - for (auto i : v) { + for (auto const &i : v) { entry->add_strings(i); } VLOG(3) << "Property: set_strings " << v.size(); @@ -352,7 +352,7 @@ void Property::SetStrings(const std::string &name, auto entry = property_.add_entrys(); entry->set_name(name); entry->set_type(type); - for (auto i : v) { + for (auto const &i : v) { entry->add_strings(i); } VLOG(3) << "Property: set_strings " << v[0] << " name: " << name; diff --git a/paddle/fluid/jit/serializer_utils.cc b/paddle/fluid/jit/serializer_utils.cc index 5b58b9d4173129..4fdc07f55ac745 100644 --- a/paddle/fluid/jit/serializer_utils.cc +++ b/paddle/fluid/jit/serializer_utils.cc @@ -79,7 +79,7 @@ const std::vector> PdmodelFilePaths( std::string dir_path = format_path.substr(0, format_path.length() - layer_name.length()); DIR* dir = opendir(dir_path.c_str()); - struct dirent* ptr; + struct dirent* ptr = nullptr; while ((ptr = readdir(dir)) != nullptr) { std::string file_name = ptr->d_name; diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 7add694a04f68f..6af73d8f48958d 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -212,10 +212,6 @@ class AllocatorFacadePrivate { platform::CustomPlace(dev_type, dev_id)); } } - if (FLAGS_use_stream_safe_cuda_allocator) { - WrapStreamSafeCustomDeviceAllocatorForDefault(); - is_stream_safe_cuda_allocator_used_ = true; - } #endif break; } @@ -576,13 +572,14 @@ class AllocatorFacadePrivate { #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE - bool HasCustomDevice(const platform::CustomPlace& place, - phi::stream::stream_t stream) { + bool HasCustomDeviceAllocator(const platform::CustomPlace& place, + phi::stream::stream_t stream) { auto it = custom_device_allocators_.find(place); if (it == custom_device_allocators_.end()) { return false; } - auto& allocator_map = it->second; + const std::map>& + allocator_map = it->second; return allocator_map.find(stream) != allocator_map.end(); } @@ -590,10 +587,15 @@ class AllocatorFacadePrivate { const platform::CustomPlace& place, phi::stream::stream_t stream, bool create_if_not_found = false) { + if (stream == GetDefaultStream(place)) { + VLOG(7) << "Get Allocator by passing in a default stream"; + return GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); + } + /* shared_lock_guard */ { std::shared_lock lock_guard( custom_device_allocator_mutex_); - if (LIKELY(HasCustomDevice(place, stream))) { + if (LIKELY(HasCustomDeviceAllocator(place, stream))) { return custom_device_allocators_[place][stream]; } else { PADDLE_ENFORCE_NE(create_if_not_found, @@ -627,17 +629,11 @@ class AllocatorFacadePrivate { return iter->second; } - void RecordStream(std::shared_ptr allocation, - phi::stream::stream_t stream) { - std::shared_ptr - stream_safe_custom_device_allocation = - std::dynamic_pointer_cast( - allocation); - if (stream_safe_custom_device_allocation != nullptr) { - stream_safe_custom_device_allocation->RecordStream(stream); - } else { - VLOG(6) << "RecordStream for a non-StreamSafeCustomDeviceAllocation"; - } + phi::stream::stream_t GetDefaultStream( + const platform::CustomPlace& place) const { + const std::shared_ptr& allocator = + GetDefaultStreamSafeCustomDeviceAllocator(place); + return allocator->GetDefaultStream(); } void SetDefaultStream(const platform::CustomPlace& place, @@ -662,6 +658,34 @@ class AllocatorFacadePrivate { << ") in " << place; } + void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream) { + std::shared_ptr + stream_safe_custom_device_allocation = + std::dynamic_pointer_cast( + allocation); + if (stream_safe_custom_device_allocation != nullptr) { + stream_safe_custom_device_allocation->RecordStream(stream); + } else { + VLOG(6) << "RecordStream for a non-StreamSafeCustomDeviceAllocation"; + } + } + + phi::stream::stream_t GetStream( + const std::shared_ptr& allocation) const { + const std::shared_ptr + stream_safe_custom_device_allocation = + std::dynamic_pointer_cast( + allocation); + if (stream_safe_custom_device_allocation != nullptr) { + return stream_safe_custom_device_allocation->GetOwningStream(); + } + + VLOG(6) << "GetStream for a non-StreamSafeCustomDeviceAllocation"; + return static_cast( + platform::DeviceContextPool::Instance().Get(allocation->place())) + ->stream(); + } #endif private: @@ -1108,10 +1132,41 @@ class AllocatorFacadePrivate { allocators_[p] = std::make_shared(p); } - void InitNaiveBestFitCustomDeviceAllocator(platform::CustomPlace p, - phi::stream::stream_t stream) { + std::shared_ptr CreateCustomDeviceAllocator( + platform::CustomPlace p) { + return std::make_shared(p); + } + + void InitStreamSafeCustomDeviceAllocator(platform::CustomPlace p, + phi::stream::stream_t stream) { + PADDLE_ENFORCE_EQ( + strategy_, + AllocatorStrategy::kAutoGrowth, + platform::errors::Unimplemented( + "Only support auto-growth strategey for " + "StreamSafeCustomDeviceAllocator, " + "the allocator strategy %d is unsupported for multi-stream", + static_cast(strategy_))); + if (LIKELY(!HasCustomDeviceAllocator(p, stream))) { + VLOG(8) << "Init StreamSafeCustomDeviceAllocator for stream " << stream + << " in place " << p; + InitAutoGrowthCustomDeviceAllocator(p, stream); + WrapStreamSafeCustomDeviceAllocator(p, stream); + } + } + + void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, + phi::stream::stream_t stream) { + auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20; + VLOG(4) << "FLAGS_auto_growth_chunk_size_in_mb is " + << FLAGS_auto_growth_chunk_size_in_mb; + + auto custom_allocator = + std::make_shared(p); + auto alignment = phi::DeviceManager::GetMinChunkSize(p); custom_device_allocators_[p][stream] = - std::make_shared(p); + std::make_shared( + custom_allocator, alignment, chunk_size, allow_free_idle_chunk_); } void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, @@ -1146,20 +1201,6 @@ class AllocatorFacadePrivate { } } - void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, - phi::stream::stream_t stream) { - auto chunk_size = FLAGS_auto_growth_chunk_size_in_mb << 20; - VLOG(4) << "FLAGS_auto_growth_chunk_size_in_mb is " - << FLAGS_auto_growth_chunk_size_in_mb; - - auto custom_allocator = - std::make_shared(p); - auto alignment = phi::DeviceManager::GetMinChunkSize(p); - custom_device_allocators_[p][stream] = - std::make_shared( - custom_allocator, alignment, chunk_size, allow_free_idle_chunk_); - } - void WrapStreamSafeCustomDeviceAllocator(platform::CustomPlace p, phi::stream::stream_t stream) { std::shared_ptr& allocator = @@ -1167,18 +1208,6 @@ class AllocatorFacadePrivate { allocator = std::make_shared(allocator, p, stream); } - - void InitStreamSafeCustomDeviceAllocator(platform::CustomPlace p, - phi::stream::stream_t stream) { - VLOG(8) << "Init CustomDevice allocator for stream " << stream - << " in place " << p; - if (strategy_ == AllocatorStrategy::kAutoGrowth) { - InitAutoGrowthCustomDeviceAllocator(p, stream); - } else { - InitNaiveBestFitCustomDeviceAllocator(p, stream); - } - WrapStreamSafeCustomDeviceAllocator(p, stream); - } #endif void InitSystemAllocators() { @@ -1419,12 +1448,20 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, const phi::Stream& stream) { #ifdef PADDLE_WITH_CUSTOM_DEVICE if (platform::is_custom_place(place)) { + if (!GetPrivate()->IsStreamSafeCUDAAllocatorUsed()) { + VLOG(6) << "Warning: StreamSafeCustomDeviceAllocator is not used!"; + return Alloc(place, size); + } platform::CustomPlace p(place); - phi::stream::stream_t s = - reinterpret_cast(stream.id()); - return GetPrivate() - ->GetAllocator(p, s, /* create_if_not_found = */ true) - ->Allocate(size); + if (LIKELY(size > 0 && FLAGS_use_system_allocator == false)) { + phi::stream::stream_t s = + reinterpret_cast(stream.id()); + return GetPrivate() + ->GetAllocator(p, s, /* create_if_not_found = */ true) + ->Allocate(size); + } else { + return GetPrivate()->GetAllocator(p, size)->Allocate(size); + } } #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -1552,10 +1589,32 @@ void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) { #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE +uint64_t AllocatorFacade::Release(const platform::CustomPlace& place, + phi::stream::stream_t stream) { + AllocatorFacadePrivate* m = GetPrivate(); + if (!m->IsStreamSafeCUDAAllocatorUsed()) { + VLOG(6) << "Warning: StreamSafeCustomDeviceAllocator is not used!"; + return Release(place); + } + + return m->GetAllocator(place, stream)->Release(place); +} + +void AllocatorFacade::RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream) { + GetPrivate()->RecordStream(allocation, stream); +} + const std::shared_ptr& AllocatorFacade::GetAllocator( const platform::Place& place, phi::stream::stream_t stream) { AllocatorFacadePrivate* m = GetPrivate(); + if (!m->IsStreamSafeCUDAAllocatorUsed()) { + VLOG(6) << "Warning: StreamSafeCustomDeviceAllocator is not used!"; + return GetAllocator(place); + } + + if (platform::is_custom_place(place) && FLAGS_use_system_allocator == false) { return m->GetAllocator(place, stream, /*create_if_not_found=*/true); @@ -1563,9 +1622,9 @@ const std::shared_ptr& AllocatorFacade::GetAllocator( return m->GetAllocator(place, /* A non-zero num to choose allocator_ */ 1); } -void AllocatorFacade::RecordStream(std::shared_ptr allocation, - phi::stream::stream_t stream) { - GetPrivate()->RecordStream(allocation, stream); +phi::stream::stream_t AllocatorFacade::GetStream( + const std::shared_ptr& allocation) const { + return GetPrivate()->GetStream(allocation); } void AllocatorFacade::SetDefaultStream(const platform::CustomPlace& place, diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index 0131d56c6f6428..9d2c85eccf4555 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -97,11 +97,14 @@ class AllocatorFacade { #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE + uint64_t Release(const platform::CustomPlace& place, + phi::stream::stream_t stream); + void RecordStream(std::shared_ptr allocation, + phi::stream::stream_t stream); const std::shared_ptr& GetAllocator(const platform::Place& place, phi::stream::stream_t stream); - void RecordStream(std::shared_ptr allocation, - phi::stream::stream_t stream); - + phi::stream::stream_t GetStream( + const std::shared_ptr& allocation) const; void SetDefaultStream(const platform::CustomPlace& place, phi::stream::stream_t stream); #endif diff --git a/paddle/fluid/memory/allocation/cpu_allocator.cc b/paddle/fluid/memory/allocation/cpu_allocator.cc index dde362ebed4ef7..398c015627860d 100644 --- a/paddle/fluid/memory/allocation/cpu_allocator.cc +++ b/paddle/fluid/memory/allocation/cpu_allocator.cc @@ -38,7 +38,7 @@ void CPUAllocator::FreeImpl(phi::Allocation *allocation) { } phi::Allocation *CPUAllocator::AllocateImpl(size_t size) { - void *p; + void *p = nullptr; #ifdef _WIN32 p = _aligned_malloc(size, kAlignment); #else diff --git a/paddle/fluid/memory/allocation/mmap_allocator.cc b/paddle/fluid/memory/allocation/mmap_allocator.cc index 6be6436b4db7b0..5e857f9acb7171 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.cc +++ b/paddle/fluid/memory/allocation/mmap_allocator.cc @@ -321,7 +321,7 @@ void MemoryMapFdSet::Clear() { VLOG(3) << "PID: " << getpid() << ", MemoryMapFdSet: set size - " << fd_set_.size(); std::lock_guard guard(mtx_); - for (auto fd : fd_set_) { + for (auto const &fd : fd_set_) { int rlt = shm_unlink(fd.c_str()); if (rlt == 0) { VLOG(3) << "PID: " << getpid() << ", MemoryMapFdSet: clear " << fd; @@ -375,7 +375,7 @@ void MemoryMapAllocationPool::SetMaxPoolSize(const int &size) { void MemoryMapAllocationPool::Clear() { std::lock_guard guard(mtx_); - for (auto mmap : memory_map_allocations_) { + for (auto const &mmap : memory_map_allocations_) { int rlt = shm_unlink(mmap.file_name_.c_str()); if (rlt == 0) { VLOG(4) << "MemoryMapAllocationPool: clear " << mmap.file_name_; diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc index 9f513448eea266..a296d254266ab2 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc @@ -35,7 +35,7 @@ StreamSafeCUDAAllocation::StreamSafeCUDAAllocation( underlying_allocation->size(), underlying_allocation->place()), underlying_allocation_(std::move(underlying_allocation)), - owning_stream_(std::move(owning_stream)), + owning_stream_(owning_stream), allocator_(allocator->shared_from_this()) {} void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) { @@ -148,8 +148,8 @@ StreamSafeCUDAAllocator::StreamSafeCUDAAllocator( gpuStream_t default_stream, bool in_cuda_graph_capturing) : underlying_allocator_(std::move(underlying_allocator)), - place_(std::move(place)), - default_stream_(std::move(default_stream)), + place_(place), + default_stream_(default_stream), in_cuda_graph_capturing_(in_cuda_graph_capturing) { if (LIKELY(!in_cuda_graph_capturing)) { std::lock_guard lock_guard(allocator_map_lock_); diff --git a/paddle/fluid/memory/allocation/system_allocator_test.cc b/paddle/fluid/memory/allocation/system_allocator_test.cc index e04d14f0adfde0..16b538599df258 100644 --- a/paddle/fluid/memory/allocation/system_allocator_test.cc +++ b/paddle/fluid/memory/allocation/system_allocator_test.cc @@ -26,7 +26,7 @@ PHI_DECLARE_bool(use_pinned_memory); void TestAllocator(paddle::memory::detail::SystemAllocator* a, size_t size) { bool freed = false; { - size_t index; + size_t index; // NOLINT void* p = a->Alloc(&index, size); if (size > 0) { EXPECT_NE(p, nullptr); diff --git a/paddle/fluid/operators/bilateral_slice_op.cc b/paddle/fluid/operators/bilateral_slice_op.cc index 53386c1551d0f5..1a6561fc383cc6 100644 --- a/paddle/fluid/operators/bilateral_slice_op.cc +++ b/paddle/fluid/operators/bilateral_slice_op.cc @@ -51,7 +51,7 @@ class BilateralSliceOp : public framework::OperatorWithKernel { int64_t coeffs_chans = grid_dims[1]; int64_t input_chans = input_dims[1]; - int64_t output_chans; + int64_t output_chans = 0; if ((!ctx->IsRuntime()) && ((coeffs_chans < 0) || (input_chans < 0))) { output_chans = -1; } else { diff --git a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc index 0dd43e761da391..9caca06f53ad3a 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_infer_op.cc @@ -50,7 +50,7 @@ class ConditionalBlockInferOp : public ConditionalOp { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - bool need_run; + bool need_run = false; if (Attr("is_scalar_condition")) { // When is_scalar_condition is True, the conditional variable is a scalar, // whether need to execute the operators in sub-block depends on the diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 501761d82d0343..d7166a5ad02672 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -51,7 +51,7 @@ class ConditionalBlockOp : public ConditionalOp { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - bool need_run; + bool need_run = false; if (Attr("is_scalar_condition")) { // When is_scalar_condition is True, the conditional variable is a scalar, // whether need to execute the operators in sub-block depends on the @@ -147,7 +147,7 @@ class ConditionalBlockGradOp : public ConditionalOp { private: void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { - bool need_run; + bool need_run = false; if (Attr("is_scalar_condition")) { auto xs = this->InputTensors(scope, ConditionalOp::kCondition); need_run = ScalarCondition(xs); diff --git a/paddle/fluid/operators/controlflow/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc index 9f67b1d4b6e183..9262ca59af970b 100644 --- a/paddle/fluid/operators/controlflow/get_places_op.cc +++ b/paddle/fluid/operators/controlflow/get_places_op.cc @@ -52,7 +52,7 @@ class GetPlacesOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - bool is_gpu; + bool is_gpu = false; if (Attr("device_type") == "AUTO") { is_gpu = platform::is_gpu_place(place); } else { diff --git a/paddle/fluid/operators/controlflow/pylayer_op_helper.cc b/paddle/fluid/operators/controlflow/pylayer_op_helper.cc index dabe561eea3e73..9dc53d428ef1d2 100644 --- a/paddle/fluid/operators/controlflow/pylayer_op_helper.cc +++ b/paddle/fluid/operators/controlflow/pylayer_op_helper.cc @@ -47,7 +47,7 @@ static void FindAllPyLayerOpAndPyLayerGradOp( for (size_t i = 1; i < program.Size(); ++i) { auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { - auto *op = block.Op(j); + auto *op = block.Op(static_cast(j)); if (op->Type() == "pylayer") { fwd_ops->emplace_back(op); } else if (op->Type() == "pylayer_grad") { diff --git a/paddle/fluid/operators/detection/mask_util.cc b/paddle/fluid/operators/detection/mask_util.cc index 70fdf4b8999f4e..f3e5b166b43b8a 100644 --- a/paddle/fluid/operators/detection/mask_util.cc +++ b/paddle/fluid/operators/detection/mask_util.cc @@ -42,10 +42,10 @@ void Decode(const uint32_t* cnts, int m, uint8_t* mask) { typedef uint32_t uint; void Poly2Mask(const float* xy, int k, int h, int w, uint8_t* mask) { - int j, m = 0; + int j = 0, m = 0; double scale = 5; - int *x, *y, *u, *v; - uint *a, *b; + int *x = nullptr, *y = nullptr, *u = nullptr, *v = nullptr; + uint *a = nullptr, *b = nullptr; platform::CPUPlace cpu; auto xptr = memory::Alloc(cpu, sizeof(int) * (k + 1) * 2); x = reinterpret_cast(xptr->ptr()); @@ -65,9 +65,10 @@ void Poly2Mask(const float* xy, int k, int h, int w, uint8_t* mask) { v = u + m; m = 0; for (j = 0; j < k; j++) { - int xs = x[j], xe = x[j + 1], ys = y[j], ye = y[j + 1], dx, dy, t, d; - int flip; - double s; + int xs = x[j], xe = x[j + 1], ys = y[j], ye = y[j + 1], dx = 0, dy = 0, + t = 0, d = 0; + int flip = 0; + double s = NAN; dx = abs(xe - xs); dy = abs(ys - ye); flip = (dx >= dy && xs > xe) || (dx < dy && ys > ye); @@ -100,7 +101,7 @@ void Poly2Mask(const float* xy, int k, int h, int w, uint8_t* mask) { /* get points along y-boundary and downsample */ k = m; m = 0; - double xd, yd; + double xd = NAN, yd = NAN; auto xyptr = memory::Alloc(cpu, sizeof(int) * k * 2); x = reinterpret_cast(xyptr->ptr()); y = x + k; diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 8519752bc10492..9f3f426d1ad853 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -250,7 +250,7 @@ class MultiClassNMSKernel : public framework::OpKernel { *num_nmsed_out = num_det; const T* scores_data = scores.data(); if (keep_top_k > -1 && num_det > keep_top_k) { - const T* sdata; + const T* sdata = nullptr; std::vector>> score_index_pairs; for (const auto& it : *indices) { int label = it.first; @@ -310,7 +310,7 @@ class MultiClassNMSKernel : public framework::OpKernel { auto* scores_data = scores.data(); auto* bboxes_data = bboxes.data(); auto* odata = outs->data(); - const T* sdata; + const T* sdata = nullptr; phi::DenseTensor bbox; bbox.Resize({scores.dims()[0], box_size}); int count = 0; @@ -325,7 +325,7 @@ class MultiClassNMSKernel : public framework::OpKernel { for (auto idx : indices) { odata[count * out_dim] = label; // label - const T* bdata; + const T* bdata = nullptr; if (scores_size == 3) { bdata = bboxes_data + idx * box_size; odata[count * out_dim + 1] = sdata[idx]; // score diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index a41b8a70a42833..81e8d0d3edf7e7 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -122,7 +122,7 @@ std::vector FilterStraddleAnchor( int anchor_num = static_cast(anchor->dims()[0]); auto* anchor_data = anchor->data(); if (rpn_straddle_thresh >= 0) { - int index; + int index = 0; for (int i = 0; i < anchor_num; ++i) { index = i * 4; if ((anchor_data[index + 0] >= -rpn_straddle_thresh) && diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 778e6ed277fd7e..b11840866d46b3 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -405,7 +405,7 @@ class FusedElemwiseAddActivationOp : public FusedElemwiseActivationOp { std::vector functor_names = ctx->Attrs().Get>("functor_list"); bool elemntwise_add_detected = false; - for (auto names : functor_names) { + for (auto const &names : functor_names) { if (names == "elementwise_add") { elemntwise_add_detected = true; break; @@ -430,7 +430,7 @@ class FusedElemwiseAddActivationOpGrad : public FusedElemwiseActivationOpGrad { std::vector functor_names = ctx->Attrs().Get>("functor_list"); bool elemntwise_add_grad_detected = false; - for (auto names : functor_names) { + for (auto const &names : functor_names) { if (names == "elementwise_add_grad") { elemntwise_add_grad_detected = true; break; diff --git a/paddle/fluid/operators/fused/fused_matmul_op.cc b/paddle/fluid/operators/fused/fused_matmul_op.cc index ca3d02bf9bfa11..198fd61a150780 100644 --- a/paddle/fluid/operators/fused/fused_matmul_op.cc +++ b/paddle/fluid/operators/fused/fused_matmul_op.cc @@ -82,7 +82,7 @@ class FusedMatmulOp : public framework::OperatorWithKernel { y_broadcasted = true; } - size_t M, N; + size_t M = 0, N = 0; if (trans_x) { M = dims_x[ndims_x - 1]; } else { diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 0625d5c80c08eb..541233949b5d22 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -129,7 +129,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::DDim out_dims({x_mat_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); ctx->ShareLoD("X", "Hidden"); - int xx_width; + int xx_width = 0; if (ctx->Attrs().Get("use_seq")) { xx_width = static_cast(wx_dims[1]); } else { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 400d8dcdaad2f6..d6e05a4ba3d480 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -141,7 +141,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ctx->SetOutputDim("Cell", out_dims); ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Cell"); - int xx_width; + int xx_width = 0; if (ctx->Attrs().Get("use_seq")) { xx_width = static_cast(wx_dims[1]); } else { diff --git a/paddle/fluid/operators/generator/type_mapping.py b/paddle/fluid/operators/generator/type_mapping.py index 8d3a4933c3bd0a..56e01a997e61b7 100644 --- a/paddle/fluid/operators/generator/type_mapping.py +++ b/paddle/fluid/operators/generator/type_mapping.py @@ -48,7 +48,7 @@ 'int64_t[]': 'const std::vector&', 'float[]': 'const std::vector&', 'double[]': 'const std::vector&', - 'str[]': 'const std::vector<&', + 'str[]': 'const std::vector&', } opmaker_attr_types_map = { @@ -86,8 +86,8 @@ } optional_output_type_map = { - 'Tensor': 'const paddle::optional&', - 'Tensor[]': 'const paddle::optional>&', + 'Tensor': 'const paddle::optional', + 'Tensor[]': 'const paddle::optional>', } # ------------------------------ phi attr ------------------------------ diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 315bd225809729..f199fa096d0df6 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -333,7 +333,8 @@ class GRUCPUKernel : public framework::OpKernel { auto input_dims = input->dims(); auto hidden_dims = hidden->dims(); - LodTensorPtr batch_gate, batch_reset_hidden_prev, batch_hidden; + LodTensorPtr batch_gate = nullptr, batch_reset_hidden_prev = nullptr, + batch_hidden = nullptr; phi::DenseTensor batch_gate_tmp, batch_reset_hidden_prev_tmp, batch_hidden_tmp; if (is_test) { diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 2bb9bf633f0c24..1af8b247de4479 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -61,7 +61,7 @@ static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) { return; } - int out_w; + int out_w = 0; if (ctx->HasInput("Scale")) { auto scale_tensor = ctx->GetInputDim("Scale"); PADDLE_ENFORCE_EQ( @@ -151,7 +151,7 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) { return; } - int out_h, out_w; + int out_h = 0, out_w = 0; if (ctx->HasInput("Scale")) { auto scale_tensor = ctx->GetInputDim("Scale"); PADDLE_ENFORCE_EQ( @@ -247,7 +247,7 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { return; } - int out_d, out_h, out_w; + int out_d = 0, out_h = 0, out_w = 0; if (ctx->HasInput("Scale")) { auto scale_tensor = ctx->GetInputDim("Scale"); PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/operators/math/tree2col.cc b/paddle/fluid/operators/math/tree2col.cc index 21eeb52fd311a3..41c131de0f392a 100644 --- a/paddle/fluid/operators/math/tree2col.cc +++ b/paddle/fluid/operators/math/tree2col.cc @@ -98,7 +98,7 @@ class Tree2ColFunctor { phi::funcs::SetConstant constant; int64_t feature_size = feature_dims[1]; size_t patch_elem_size = 3 * static_cast(feature_size); - size_t node_count = 0, patch_count = 0, patch_size; + size_t node_count = 0, patch_count = 0, patch_size = 0; Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); std::vector> processing_list; for (size_t u = 1; u <= node_count; u++) { diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 69c64de7056459..df66ab400f40bf 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -181,7 +181,7 @@ static phi::DenseTensor FoldHeadAndLastDims(const DeviceContext &context, */ static void ReshapeTensorIntoMatrixSequence( phi::DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) { - int64_t h, w; + int64_t h = 0, w = 0; h = descriptor.height_; w = descriptor.width_; if (descriptor.trans_) { diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index ae00939d07844f..3f0fd7bfef2dcc 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -82,7 +82,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { platform::Place place = dev_place; int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; auto data_type = in_true.IsInitialized() ? in_true.type() : in_false.type(); - int rank; + int rank = 0; framework::DDim in_dims; if (in_true.IsInitialized()) { rank = in_true.dims().size(); diff --git a/paddle/fluid/operators/prim_ops/split_p_op.cc b/paddle/fluid/operators/prim_ops/split_p_op.cc index 0584de504e7706..dea336bcd263fe 100644 --- a/paddle/fluid/operators/prim_ops/split_p_op.cc +++ b/paddle/fluid/operators/prim_ops/split_p_op.cc @@ -110,7 +110,7 @@ class SplitPrimOpVarTypeInference void operator()(framework::InferVarTypeContext *ctx) const override { auto x_name = Input(ctx, "X")[0]; auto y_names = Output(ctx, "YS"); - for (auto y_name : y_names) { + for (auto const &y_name : y_names) { SetType(ctx, y_name, GetType(ctx, x_name)); SetDataType(ctx, y_name, GetDataType(ctx, x_name)); } diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index cf8197a04dd695..5cea8f59631119 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -53,7 +53,7 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { place_str = place_str.substr(0, place_str.length() - 1); std::istringstream sin(place_str); sin.seekg(std::string("PLACE(GPU:").size(), std::ios::beg); // NOLINT - size_t num; + size_t num = 0; sin >> num; place = platform::CUDAPlace(static_cast(num)); } diff --git a/paddle/fluid/operators/reader/py_reader.cc b/paddle/fluid/operators/reader/py_reader.cc index 2db8ac6b1bcb9b..f0c0409a729a5c 100644 --- a/paddle/fluid/operators/reader/py_reader.cc +++ b/paddle/fluid/operators/reader/py_reader.cc @@ -31,7 +31,7 @@ PyReader::PyReader( } void PyReader::ReadNext(paddle::framework::LoDTensorArray* out) { - bool success; + bool success = false; *out = queue_->Pop(&success); if (!success) out->clear(); } diff --git a/paddle/fluid/operators/split_lod_tensor_op.cc b/paddle/fluid/operators/split_lod_tensor_op.cc index e648575a1edca1..6b79d5c35b7838 100644 --- a/paddle/fluid/operators/split_lod_tensor_op.cc +++ b/paddle/fluid/operators/split_lod_tensor_op.cc @@ -107,7 +107,7 @@ class SplitLoDTensorOp : public framework::OperatorBase { } for (size_t t = 0; t < 2; ++t) { - phi::DenseTensor *out; + phi::DenseTensor *out = nullptr; if (t == 0) { out = out_false; } else { diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 246e9368a7d2fc..f362b05fc6d10c 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -65,12 +65,12 @@ class SplitOp : public framework::OperatorWithKernel { if (ctx->IsRuntime() && ctx->HasInput("AxisTensor")) { Variable *var = PADDLE_GET_CONST(Variable *, ctx->GetInputVarPtrs("AxisTensor")[0]); - axis_final = std::move(framework::MakePhiScalarFromVar(*var)); + axis_final = framework::MakePhiScalarFromVar(*var); } else if (!ctx->IsRuntime() && ctx->HasInput("AxisTensor")) { - axis_final = std::move(phi::Scalar(-1)); + axis_final = phi::Scalar(-1); axis_final.SetFromTensor(true); } else { - axis_final = std::move(phi::Scalar(axis)); + axis_final = phi::Scalar(axis); } // Construct sections_final diff --git a/paddle/fluid/operators/string/faster_tokenizer_op.cc b/paddle/fluid/operators/string/faster_tokenizer_op.cc index dd1e421e6cb1ae..68126e187b4e58 100644 --- a/paddle/fluid/operators/string/faster_tokenizer_op.cc +++ b/paddle/fluid/operators/string/faster_tokenizer_op.cc @@ -131,13 +131,13 @@ void WordPieceTokenizer::Tokenize(const wstring& text, vector* token_ids) const { size_t len = text.size(); if (len > max_input_chars_per_word_) { - token_ids->emplace_back(std::move(unk_token_id_)); + token_ids->emplace_back(unk_token_id_); return; } auto it = vocab_->find(text); if (it != vocab_->end()) { - token_ids->emplace_back(std::move(it->second)); + token_ids->emplace_back(it->second); return; } @@ -146,7 +146,7 @@ void WordPieceTokenizer::Tokenize(const wstring& text, while (start < len) { size_t end = len; std::wstring cur_substr; - int64_t cur_substr_id; + int64_t cur_substr_id = 0; while (start < end) { std::wstring sub = text.substr(start, end - start); if (start > 0) { @@ -162,15 +162,15 @@ void WordPieceTokenizer::Tokenize(const wstring& text, } if (cur_substr.empty()) { - token_ids->emplace_back(std::move(unk_token_id_)); + token_ids->emplace_back(unk_token_id_); return; } else { start = end; - wordpiece_ids.emplace_back(std::move(cur_substr_id)); + wordpiece_ids.emplace_back(cur_substr_id); } } for (auto& token_id : wordpiece_ids) { - token_ids->emplace_back(std::move(token_id)); + token_ids->emplace_back(token_id); } } @@ -219,9 +219,9 @@ void BertTokenizer::Tokenize(const string& text, if (IsChineseChar(w_token[0])) { auto vocab_it = vocab_->find(w_token); if (vocab_it != vocab_->end()) { - split_token_ids->emplace_back(std::move(vocab_it->second)); + split_token_ids->emplace_back(vocab_it->second); } else { - split_token_ids->emplace_back(std::move(unk_token_id_)); + split_token_ids->emplace_back(unk_token_id_); } } else { word_piece_tokenizer_.Tokenize(w_token, split_token_ids); @@ -241,29 +241,29 @@ void BertTokenizer::BuildInputsWithSpecialTokens( if (token_ids_1.empty()) { inputs->clear(); inputs->resize(token_ids_0.size() + 2); - inputs->at(0) = std::move(cls_token_id_); + inputs->at(0) = cls_token_id_; size_t i = 1; for (auto& token_id : token_ids_0) { - inputs->at(i) = std::move(token_id); + inputs->at(i) = token_id; ++i; } - inputs->at(i) = std::move(sep_token_id_); + inputs->at(i) = sep_token_id_; } else { inputs->clear(); inputs->resize(token_ids_0.size() + token_ids_1.size() + 3); - inputs->at(0) = std::move(cls_token_id_); + inputs->at(0) = cls_token_id_; size_t i = 1; for (auto& token_id : token_ids_0) { - inputs->at(i) = std::move(token_id); + inputs->at(i) = token_id; ++i; } - inputs->at(i) = std::move(sep_token_id_); + inputs->at(i) = sep_token_id_; ++i; for (auto& token_id : token_ids_1) { - inputs->at(i) = std::move(token_id); + inputs->at(i) = token_id; ++i; } - inputs->at(i) = std::move(sep_token_id_); + inputs->at(i) = sep_token_id_; } } @@ -333,9 +333,9 @@ int BertTokenizer::Encode( wstring token = unicode_text.substr(i, 1); auto it = vocab_->find(token); if (it != vocab_->end()) { - ids.emplace_back(std::move(it->second)); + ids.emplace_back(it->second); } else { - ids.emplace_back(std::move(unk_token_id_)); + ids.emplace_back(unk_token_id_); } } } diff --git a/paddle/fluid/pir/CMakeLists.txt b/paddle/fluid/pir/CMakeLists.txt index 1ff77c6d7187e0..24f5e2892de8e2 100644 --- a/paddle/fluid/pir/CMakeLists.txt +++ b/paddle/fluid/pir/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(dialect) add_subdirectory(transforms) +add_subdirectory(drr) diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 62c1129f846209..8ad46bc8906adb 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -25,7 +25,7 @@ const char* PhiKernelOp::attributes_name[attributes_num] = { // NOLINT "kernel_name", "kernel_key"}; -void PhiKernelOp::Verify() { +void PhiKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp."; auto& attributes = this->attributes(); @@ -64,7 +64,7 @@ const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT "kernel_name", "kernel_key"}; -void LegacyKernelOp::Verify() { +void LegacyKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp."; auto& attributes = this->attributes(); diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index 8a18959665e0c7..a96aa5732d5806 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -29,7 +29,7 @@ class PhiKernelOp : public pir::Op { std::string op_name(); std::string kernel_name(); phi::KernelKey kernel_key(); - void Verify(); + void VerifySig(); }; class LegacyKernelOp : public pir::Op { @@ -41,7 +41,7 @@ class LegacyKernelOp : public pir::Op { std::string op_name(); std::string kernel_name(); phi::KernelKey kernel_key(); - void Verify(); + void VerifySig(); }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index 9f51351f6ea044..c336dc7b61be18 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -197,7 +197,9 @@ def _is_optional_input(self, op_info, input_name): return True return False - def _is_optinonal_output(self, op_info, output_name): + def _is_optional_output(self, op_info, op_name, output_name): + if op_name.endswith(('_grad', '_grad_')): + return False inplace_map = op_info.inplace_map input_optional_list = op_info.input_optional_list input_name_list = op_info.input_name_list @@ -271,7 +273,7 @@ def _gen_api_args( ) return (inputs + ', ' + attrs).strip(', ') - def _gen_ret_type(self, op_info): + def _gen_ret_type(self, op_info, op_name): name_list = op_info.output_name_list type_list = op_info.output_type_list intermediate_list = op_info.output_intermediate_list @@ -285,7 +287,7 @@ def _gen_ret_type(self, op_info): ): if intermediate == 'true': continue - if self._is_optinonal_output(op_info, name): + if self._is_optional_output(op_info, op_name, name): ret.append(OPTIONAL_OUTPUT_TYPE_MAP[type]) else: ret.append(OUTPUT_TYPE_MAP[type]) @@ -293,7 +295,7 @@ def _gen_ret_type(self, op_info): elif output_num == 1: index = intermediate_list.index('false') name = name_list[index] - if self._is_optinonal_output(op_info, name): + if self._is_optional_output(op_info, op_name, name): return OPTIONAL_OUTPUT_TYPE_MAP[type_list[index]] else: return OUTPUT_TYPE_MAP[type_list[index]] @@ -304,7 +306,7 @@ def _gen_one_declare( self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr ): return API_DECLARE_TEMPLATE.format( - ret_type=self._gen_ret_type(op_info), + ret_type=self._gen_ret_type(op_info, op_name), api_name=op_name, args=self._gen_api_args( op_info, True, is_mutable_attr, is_vector_mutable_attr @@ -367,7 +369,7 @@ def _gen_handle_optional_outputs(self, op_info, op_name): ): if intermediate == 'true': continue - if self._is_optinonal_output(op_info, name): + if self._is_optional_output(op_info, op_name, name): if VECTOR_TYPE in type: ret += OPTIONAL_VECTOR_OPRESULT_OUTPUT_TEMPLATE.format( name=name, @@ -461,7 +463,7 @@ def _gen_compute_op( op_inst_name, ) - def _gen_out_split_and_ret_list(self, op_info, op_inst_name): + def _gen_out_split_and_ret_list(self, op_info, op_name, op_inst_name): name_list = op_info.output_name_list type_list = op_info.output_type_list intermediate_list = op_info.output_intermediate_list @@ -480,7 +482,7 @@ def _gen_out_split_and_ret_list(self, op_info, op_inst_name): ): if intermediate == 'true': continue - if self._is_optinonal_output(op_info, name): + if self._is_optional_output(op_info, op_name, name): ret_list.append(f'optional_{name}') elif VECTOR_TYPE in type: split_op_name = f'{name}_split_op' @@ -503,7 +505,7 @@ def _gen_return_result(self, ret_list): def _gen_one_impl( self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr ): - ret_type = self._gen_ret_type(op_info) + ret_type = self._gen_ret_type(op_info, op_name) in_combine, in_combine_op_list = self._gen_in_combine( op_info, is_mutable_attr, is_vector_mutable_attr ) @@ -514,7 +516,7 @@ def _gen_one_impl( compute_op += f' (void){op_inst_name};' out_split, ret_list = self._gen_out_split_and_ret_list( - op_info, op_inst_name + op_info, op_name, op_inst_name ) ret = API_IMPL_TEMPLATE.format( diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py new file mode 100644 index 00000000000000..2c559330eec99c --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -0,0 +1,22 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================================== +# DecompInterface gen op list +# ===================================== + + +decomp_interface_declare_gen_op_list = ['mean'] + +decomp_interface_implementation_gen_op_list = ["mean"] diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index e24902c712c1a7..ba78e7d7dc722d 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -144,6 +144,7 @@ def GenBuildInputArgsStr( 'int': 'phi::DataType::INT32', 'int64_t': 'phi::DataType::INT64', 'float': 'phi::DataType::FLOAT32', + 'double': 'phi::DataType::FLOAT64', 'std::vector': 'phi::DataType::INT64', 'const std::vector&': 'phi::DataType::INT64', 'bool': 'phi::DataType::BOOL', diff --git a/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py new file mode 100644 index 00000000000000..c760d7fb85b84e --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import yaml +from op_gen import OpCompatParser, OpInfoParser, to_pascal_case + +CPP_FILE_TEMPLATE = """ +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" + +namespace pir {{ +namespace drr {{ + +void OperationFactory::RegisterGeneratedOpCreator() {{ +{body} +}} + +}} // namespace drr +}} // namespace pir + +""" + +NORMAL_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + return rewriter.Build( + {params_code}); + }}); +""" + +MUTABLE_ATTR_FUNCTION_TEMPLATE = """ + RegisterOperationCreator( + "{op_name}", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) {{ + // mutable_attr is tensor + if (inputs.size() > {inputs_num}) {{ + return rewriter.Build( + {params_code_with_mutable_attr}); + }} else {{ + return rewriter.Build( + {params_code_no_mutable_attr}); + }} + }}); +""" + + +class OpCreatorCodeGen: + def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name): + self.op_info_items = self.parse_yaml(op_yaml_files, op_compat_yaml_file) + self.dialect_name = dialect_name + + def parse_yaml(self, op_yaml_files, op_compat_yaml_file): + op_compat_parser = OpCompatParser(op_compat_yaml_file) + + op_yaml_items = [] + for yaml_file in op_yaml_files: + with open(yaml_file, "r") as f: + ops = yaml.safe_load(f) + op_yaml_items = op_yaml_items + ops + op_info_items = [] + for op in op_yaml_items: + op_compat_item = op_compat_parser.get_compat(op['name']) + if ( + op_compat_item is not None + and op_compat_item['op'] == "pow" + and 'scalar' in op_compat_item + ): + op_compat_item = op_compat_item.pop('scalar') + op_info_items.append(OpInfoParser(op, op_compat_item)) + return op_info_items + + def gen_cpp_file_code(self, cpp_file_path): + body_code = "" + for op_info_item in self.op_info_items: + if op_info_item.infer_meta_map is None: + continue + for phi_op_name in op_info_item.op_phi_name: + ir_op_name = self.dialect_name + "." + phi_op_name + params_no_mutable_attr = [] + for i in range(len(op_info_item.input_name_list)): + params_no_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > 0: + params_no_mutable_attr.append("attrs") + + if len(op_info_item.mutable_attribute_name_list) == 0: + body_code += NORMAL_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code=", ".join(params_no_mutable_attr), + ) + else: + params_with_mutable_attr = [] + for i in range( + len(op_info_item.input_name_list) + + len(op_info_item.mutable_attribute_name_list) + ): + params_with_mutable_attr.append( + f"inputs[{i}].dyn_cast()" + ) + if len(op_info_item.attribute_name_list) > len( + op_info_item.mutable_attribute_name_list + ): + # TODO(zyfncg): Currently Op::Build Interface doesn't support this case. + continue + # params_with_mutable_attr.append("attrs") + + body_code += MUTABLE_ATTR_FUNCTION_TEMPLATE.format( + op_name=ir_op_name, + inputs_num=len(op_info_item.input_name_list), + op_class_name=(to_pascal_case(phi_op_name) + "Op"), + params_code_with_mutable_attr=",".join( + params_with_mutable_attr + ), + params_code_no_mutable_attr=", ".join( + params_no_mutable_attr + ), + ) + + with open(cpp_file_path, 'w') as f: + f.write(CPP_FILE_TEMPLATE.format(body=body_code)) + + +def ParseArguments(): + parser = argparse.ArgumentParser( + description='Generate Op Creator Files By Yaml' + ) + parser.add_argument('--op_yaml_files', type=str) + parser.add_argument('--op_compat_yaml_file', type=str) + parser.add_argument('--dialect_name', type=str) + parser.add_argument('--op_creator_file', type=str) + return parser.parse_args() + + +if __name__ == '__main__': + args = ParseArguments() + op_yaml_files = args.op_yaml_files.split(",") + op_compat_yaml_file = args.op_compat_yaml_file + op_creator_file = args.op_creator_file + dialect_name = args.dialect_name + + code_gen = OpCreatorCodeGen( + op_yaml_files, op_compat_yaml_file, dialect_name + ) + code_gen.gen_cpp_file_code(op_creator_file) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 167b950ee95e7c..8983ffa38b5629 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -19,6 +19,7 @@ import sys import yaml +from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke from op_interface_gen import ( gen_exclusive_interface_str, @@ -27,10 +28,7 @@ ) from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str -from vjp_interface_gen_op_list import ( - vjp_interface_declare_gen_op_list, - vjp_interface_implementation_gen_op_list, -) +from vjp_interface_black_list import vjp_interface_black_list # import from paddle/fluid/primitive/code_gen/gen.py sys.path.append( @@ -61,6 +59,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/framework/infershape_utils.h" @@ -99,7 +98,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {build_mutable_attr_is_input} {build_attr_num_over_1} {build_mutable_attr_is_input_attr_num_over_1} - void Verify(); + void VerifySig(); {get_inputs_and_outputs} {exclusive_interface} }}; @@ -477,7 +476,7 @@ def parse_mutable_attribute(self): if (self.op_compat_item['op'] == "isclose") or ( self.op_compat_item['op'] == "allclose" ): - data_type = "float" + data_type = "double" mutable_attribute_type_list.append( [ "paddle::dialect::ScalarAttribute", @@ -1036,9 +1035,11 @@ def OpGenerator( if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list + and op_info.op_phi_name[0] not in vjp_interface_black_list ): op_interfaces += ["paddle::dialect::VjpInterface"] + if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list: + op_interfaces += ["paddle::dialect::DecompInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) @@ -1444,7 +1445,7 @@ def OpGenerator( if ( op_info.backward_name and op_info.op_phi_name[0] - in vjp_interface_implementation_gen_op_list + not in vjp_interface_black_list ): op_vjp_str = gen_op_vjp_str( op_class_name, diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 9c8ff889f2b219..299d4197b79475 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list + # generator interfaces -from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list +from vjp_interface_black_list import vjp_interface_black_list OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ @@ -26,12 +28,12 @@ {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """ - pir::CombineOp combine_op_obj = + pir::CombineOp combine_op_obj_{input_name} = op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); std::vector {input_name}; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ + for (size_t idx = 0; idx < combine_op_obj_{input_name}.inputs().size(); idx++) {{ {input_name}.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared(combine_op_obj_{input_name}.inputs()[idx])); }}""" OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """ @@ -63,6 +65,23 @@ std::make_shared(out_grads[{index}][idx])); }}""" +OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE = """ + paddle::optional {output_grad_name}; + if (!IsEmptyValue(out_grads[{idx1}][{idx2}])){{ + {output_grad_name} = paddle::make_optional(Tensor(std::make_shared(out_grads[{idx1}][{idx2}]))); + }}""" + +OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE = """ + paddle::optional> {output_grad_name}; + std::vector optional_{output_grad_name}; + if (!IsEmptyValue(out_grads[{index}])){{ + for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{ + optional_{output_grad_name}.emplace_back( + std::make_shared(out_grads[{index}][idx])); + }} + {output_grad_name} = paddle::make_optional>(optional_{output_grad_name}); + }}""" + OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" @@ -131,26 +150,25 @@ def gen_op_vjp_str( grad_idx = -1 for idx in range(len(bw_input_list)): build_args_str += bw_input_list[idx] + ", " - if op_grad_info.input_optional_list[idx] == 'true': - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + input_type = input_types_map[op_grad_info.input_type_list[idx]] + if ( + bw_input_list[idx] in op_info.input_name_list + or bw_input_list[idx] in op_info.output_name_list + ): + if op_grad_info.input_optional_list[idx] == 'true': + if input_type == 'Tensor': + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], + ) ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + else: + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], + ) ) - ) - else: - if ( - bw_input_list[idx] in op_info.input_name_list - or bw_input_list[idx] in op_info.output_name_list - ): - input_type = input_types_map[op_grad_info.input_type_list[idx]] + else: if input_type == 'Tensor': forward_input_output_code += ( OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( @@ -164,9 +182,22 @@ def gen_op_vjp_str( input_name=bw_input_list[idx], ) ) + else: + grad_idx += 1 + if op_grad_info.input_optional_list[idx] == 'true': + if input_type == 'Tensor': + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], + idx1=grad_idx, + idx2=0, + ) + ) + else: + forward_input_output_code += OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], index=grad_idx + ) else: - grad_idx += 1 - input_type = input_types_map[op_grad_info.input_type_list[idx]] if input_type == 'Tensor': forward_output_grad_code += ( OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( @@ -285,6 +316,8 @@ def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) - if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: + if op_info.op_phi_name[0] not in vjp_interface_black_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" + if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list: + exclusive_interface_str += "\n static std::vector> Decomp(pir::Operation* op);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 1b8c82b27d90be..3a2515f278915a 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -14,7 +14,7 @@ # verify OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify() {{ +void {op_name}::VerifySig() {{ VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}."; VLOG(4) << "Verifying inputs:"; {{ @@ -36,7 +36,7 @@ """ GRAD_OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify() {{}} +void {op_name}::VerifySig() {{}} """ INPUT_TYPE_CHECK_TEMPLATE = """ diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index fec69b8ce5a4ec..e2d17e7f118023 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -71,32 +71,33 @@ NEED_GEN_STATIC_ONLY_APIS = ['fetch'] NO_NEED_GEN_STATIC_ONLY_APIS = [ - 'set_value_with_tensor', - 'set_value_with_tensor_', - 'fused_bn_add_activation_', - 'fused_batch_norm_act_', 'add_n_', - 'set_value', - 'assign_value', - 'set_value_', - 'embedding_grad_sparse', 'add_n_with_kernel', - 'print', - 'send_v2', - 'shadow_feed', - 'recv_v2', - 'rnn_', - 'fused_scale_bias_relu_conv_bnstats', + 'assign_value', 'batch_norm_', + 'c_allgather', + 'c_allreduce_max', 'c_allreduce_sum', 'c_embedding', 'c_identity', 'c_reduce_sum', - 'c_allreduce_max', - 'c_allgather', + 'dpsgd', + 'embedding_grad_sparse', + 'fused_attention', + 'fused_batch_norm_act_', + 'fused_bn_add_activation_', + 'fused_feedforward', + 'fused_scale_bias_relu_conv_bnstats', + 'print', + 'recv_v2', + 'rnn_', 'seed', - "fused_attention", - "fused_feedforward", + 'send_v2', + 'set_value', + 'set_value_', + 'set_value_with_tensor', + 'set_value_with_tensor_', + 'shadow_feed', ] diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py new file mode 100644 index 00000000000000..c63e0c4e418338 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================================== +# VjpInterface gen op list +# ===================================== +# we don't support vjp function code +# gen now, so we use a whitelist to +# control the generation of Vjp methods. +# TODO(wanghao107) +# remove this file and support Vjp methods +# code gen. + + +vjp_interface_black_list = [ + 'frobenius_norm', + 'write_to_array', + 'fused_attention', + 'fused_feedforward', + 'set_value', + 'set_value_with_tensor', + 'silu_grad', + 'fused_dropout_add', + 'fused_rotary_position_embedding', +] diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py deleted file mode 100644 index 3a559ef8dedf84..00000000000000 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ===================================== -# VjpInterface gen op list -# ===================================== -# we don't support vjp function code -# gen now, so we use a whitelist to -# control the generation of Vjp methods. -# TODO(wanghao107) -# remove this file and support Vjp methods -# code gen. - - -vjp_interface_declare_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "pow", - "rsqrt", - "subtract", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] -vjp_interface_implementation_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "subtract", - "pow", - "rsqrt", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] diff --git a/paddle/fluid/pir/dialect/operator/interface/decomp.h b/paddle/fluid/pir/dialect/operator/interface/decomp.h new file mode 100644 index 00000000000000..10a6e51e7db3c6 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/decomp.h @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { +class DecompInterface : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept( + std::vector> (*decomp)(pir::Operation* op)) + : decomp_(decomp) {} + std::vector> (*decomp_)(pir::Operation* op); + }; + + template + struct Model : public Concept { + static std::vector> Decomp(pir::Operation* op) { + return ConcreteOp::Decomp(op); + } + Model() : Concept(Decomp) {} + }; + + /// Constructor + DecompInterface(pir::Operation* op, Concept* impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + std::vector> Decomp(pir::Operation* op) { + return impl_->decomp_(op); + } + + private: + Concept* impl_; +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index ce8bdb6c6829f8..8a4049ff09544b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" @@ -37,3 +38,4 @@ std::vector> VjpInterface::Vjp( IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt index 7954e000baf519..c2209089dacb30 100644 --- a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt @@ -209,6 +209,6 @@ target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR}) cc_library( pd_op_dialect - SRCS op_dialect.cc manual_op_vjp.cc ${op_vjp_source_file} + SRCS op_dialect.cc manual_op_decomp.cc manual_op_vjp.cc ${op_vjp_source_file} DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental pd_op_dialect_utils op_yaml_info_parser) diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 557f8c71060001..c235799633896b 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -19,8 +19,10 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/utils.h" #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" namespace paddle { @@ -109,41 +111,112 @@ void IfOp::Print(pir::IrPrinter &printer) { } os << "\n }"; } -void IfOp::Verify() {} +void IfOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: IfOp."; + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + + if ((*this)->operand_source(0).type().isa()) { + PADDLE_ENFORCE( + (*this) + ->operand_source(0) + .type() + .dyn_cast() + .dtype() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 1th input, it should be a " + "bool DenseTensorType.")); + } + + PADDLE_ENFORCE_EQ((*this)->num_regions(), + 2u, + phi::errors::PreconditionNotMet( + "The size %d of regions must be equal to 2.", + (*this)->num_regions())); +} + +void IfOp::VerifyRegion() { + VLOG(4) << "Start Verifying sub regions for: IfOp."; + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", + (*this)->region(0).size())); + + if ((*this)->num_results() != 0) { + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + (*this)->region(1).size(), + phi::errors::PreconditionNotMet("The size %d of true_region must be " + "equal to the size %d of false_region.", + (*this)->region(0).size(), + (*this)->region(1).size())); + + auto *true_last_op = (*this)->region(0).front()->back(); + auto *false_last_op = (*this)->region(1).front()->back(); + PADDLE_ENFORCE_EQ(true_last_op->isa(), + true, + phi::errors::PreconditionNotMet( + "The last of true block must be YieldOp")); + PADDLE_ENFORCE_EQ(true_last_op->num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of true block op's input must be " + "equal to IfOp's outputs num.")); + PADDLE_ENFORCE_EQ(false_last_op->isa(), + true, + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + PADDLE_ENFORCE_EQ(false_last_op->num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); + } +} void WhileOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types) { + pir::Value cond, + const std::vector &inputs) { + argument.AddInput(cond); argument.AddInputs(inputs); - argument.AddOutputs(output_types); - argument.AddRegions(2u); -} -pir::Block *WhileOp::cond_block() { - pir::Region &cond_region = (*this)->region(0); - if (cond_region.empty()) cond_region.emplace_back(); - return cond_region.front(); + for (auto val : inputs) { + argument.AddOutput(val.type()); + } + argument.AddRegion(nullptr); } pir::Block *WhileOp::body_block() { - pir::Region &body_region = (*this)->region(1); + pir::Region &body_region = (*this)->region(0); if (body_region.empty()) body_region.emplace_back(); return body_region.front(); } +pir::Value WhileOp::cond() { return (*this)->operand_source(0); } void WhileOp::Print(pir::IrPrinter &printer) { auto &os = printer.os; auto op = operation(); printer.PrintOpResult(op); - os << " \"" << name() << "\""; - printer.PrintOpOperands(op); - os << " -> "; - printer.PrintOpReturnType(op); - os << "{"; - for (auto item : *cond_block()) { - os << "\n "; - printer.PrintOperation(item); - } - os << "\n } do {"; + os << " = \"" << name() << "\"("; + printer.PrintValue(cond()); + os << ") ["; + auto operands = (*this)->operands_source(); + pir::PrintInterleave( + operands.begin() + 1, + operands.end(), + [&](pir::Value v) { printer.PrintValue(v); }, + [&]() { os << ", "; }); + os << "] { \n ^"; + pir::PrintInterleave( + body_block()->args_begin(), + body_block()->args_end(), + [&](pir::Value v) { printer.PrintValue(v); }, + [&]() { os << ", "; }); for (auto item : *body_block()) { os << "\n "; printer.PrintOperation(item); diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 99444f78da5688..3ad3a7c4215c22 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -41,9 +41,20 @@ class IfOp : public pir::Op { pir::Block *true_block(); pir::Block *false_block(); void Print(pir::IrPrinter &printer); // NOLINT - void Verify(); + void VerifySig(); + void VerifyRegion(); }; +/// +/// \brief The WhileOp is an operation that iterates over a loop body based on a +/// condition. It takes two inputs: cond_value and loop_vars. The output of the +/// WhileOp must have the same arity (length and structure) with loop_vars." The +/// semantics of WhileOp[outputs = while_op(cond, inputs)] are as below: +/// outputs = inputs +/// while(cond){ +/// cond, outputs = body(outputs) +/// } +/// class WhileOp : public pir::Op { public: using Op::Op; @@ -53,12 +64,13 @@ class WhileOp : public pir::Op { static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types); - pir::Block *cond_block(); + pir::Value cond, + const std::vector &inputs); pir::Block *body_block(); + pir::Value cond(); void Print(pir::IrPrinter &printer); // NOLINT - void Verify() {} + void VerifySig() {} + void VerifyRegion() {} }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index eb5acbf2388ea8..be652e48263301 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -100,5 +100,24 @@ pir::OpResult split_with_num_grad(const std::vector& out_grad, out_grad_combine_op.out(), axis); return split_grad_op.result(0); } + +pir::OpResult ones(const std::vector& shape, + phi::DataType dtype, + const Place& place) { + return paddle::dialect::full(shape, 1, dtype, place); +} + +pir::OpResult ones_like(pir::Value x_, + phi::DataType dtype, + const Place& place) { + return paddle::dialect::full_like(x_, 1, dtype, place); +} + +pir::OpResult zeros(const std::vector& shape, + phi::DataType dtype, + const Place& place) { + return paddle::dialect::full(shape, 0, dtype, place); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h index fe579295ad5a09..a9df64a905b24d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h @@ -47,5 +47,18 @@ pir::OpResult split_with_num_grad(const std::vector& out_grad, pir::OpResult split_with_num_grad(const std::vector& out_grad, const pir::Value& axis); + +pir::OpResult ones(const std::vector& shape, + phi::DataType dtype = phi::DataType::FLOAT32, + const Place& place = phi::CPUPlace()); + +pir::OpResult ones_like(pir::Value x_, + phi::DataType dtype = phi::DataType::UNDEFINED, + const Place& place = {}); + +pir::OpResult zeros(const std::vector& shape, + phi::DataType dtype = phi::DataType::FLOAT32, + const Place& place = phi::CPUPlace()); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index eb5f1f5a536703..00ba7da80aa253 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -50,7 +50,7 @@ OpInfoTuple AddNOp::GetOpInfo() { return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n"); } -void AddNOp::Verify() { +void AddNOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNOp."; VLOG(4) << "Verifying inputs:"; { @@ -222,7 +222,7 @@ void AddN_Op::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void AddN_Op::Verify() { +void AddN_Op::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddN_Op."; VLOG(4) << "Verifying inputs:"; { @@ -345,7 +345,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void AddNWithKernelOp::Verify() { +void AddNWithKernelOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: " "AddNWithKernelOp."; VLOG(4) << "Verifying inputs:"; @@ -429,9 +429,9 @@ OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() { paddle::dialect::OpRunTimeInfo run_time_info( "FusedGemmEpilogueInferMeta", {"x", "y", "bias", "trans_x", "trans_y", "activation"}, - "", - {""}, - {""}, + {"fused_gemm_epilogue"}, + {"x", "y", "bias", "trans_x", "trans_y", "activation"}, + {}, {}, {}, {}); @@ -561,7 +561,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void FusedGemmEpilogueOp::Verify() { +void FusedGemmEpilogueOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: " "FusedGemmEpilogueOp."; VLOG(4) << "Verifying inputs:"; @@ -674,9 +674,15 @@ OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() { "trans_x", "trans_y", "activation_grad"}, - "", - {""}, - {""}, + {"fused_gemm_epilogue_grad"}, + {"x", + "y", + "reserve_space", + "out_grad", + "trans_x", + "trans_y", + "activation_grad"}, + {}, {}, {}, {}); @@ -833,7 +839,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void FusedGemmEpilogueGradOp::Verify() {} +void FusedGemmEpilogueGradOp::VerifySig() {} void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta); @@ -983,7 +989,7 @@ void SplitGradOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void SplitGradOp::Verify() { +void SplitGradOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp."; VLOG(4) << "Verifying inputs:"; { diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index c6fc7cb32b3165..317ce64feea084 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -45,7 +45,7 @@ class AddNOp : public pir::Op { pir::Value out_grad_, pir::Value axis_); - void Verify(); + void VerifySig(); pir::Value out_grad() { return operand_source(0); } pir::Value axis() { return operand_source(1); } pir::OpResult x_grad() { return result(0); } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc new file mode 100644 index 00000000000000..e6c84ca2934774 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/primitive/composite/composite.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/int_array.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/op_base.h" + +// TODO(chenzhuo) +// this file will be generated in pd_op_decomp.cc + +namespace paddle { +namespace dialect { +using IntArray = paddle::experimental::IntArray; + +std::vector> MeanOp::Decomp(pir::Operation* op) { + MeanOp op_obj = op->dyn_cast(); + (void)op_obj; + + VLOG(4) << "Decomp Prepare inputs of mean"; + + Tensor x(std::make_shared(op_obj.x())); + + VLOG(4) << "Decomp prepare attributes of mean"; + + IntArray axis = op->attribute("axis") + .dyn_cast() + .data(); + + bool keepdim = op->attribute("keepdim").dyn_cast().data(); + VLOG(4) << "Decomp mean keep_dim " << keepdim; + + VLOG(4) << "Decomp prepare call mean's decomp interface"; + + Tensor op_res = + paddle::primitive::details::mean_decomp( + x, axis, keepdim); + + auto org_res = op->results(); + std::vector> res(org_res.size()); + res[0].push_back( + std::static_pointer_cast(op_res.impl()) + ->value() + .dyn_cast()); + return res; +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc index 3b69d68eb65f3d..f10db043d1523d 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_attribute.cc @@ -50,10 +50,10 @@ phi::Scalar ScalarAttribute::data() { IntArrayAttribute IntArrayAttribute::Parse(pir::IrParser &parser) { // NOLINT Token buket_token = parser.ConsumeToken(); - std::vector vec{}; + std::vector vec{}; while (parser.PeekToken().val_ != "]") { Token val_token = parser.ConsumeToken(); - vec.push_back(atoll(val_token.val_.c_str())); + vec.push_back(atoi(val_token.val_.c_str())); if (parser.PeekToken().val_ == "]") break; parser.ConsumeToken(); } diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 9a7c6b9de2ea26..e484d7812d2daa 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -126,7 +126,7 @@ pir::Type OperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT break; } parser.ConsumeToken(); - parser.lexer->Unget(peek_token_val.size() - 1); + parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); if (parser.PeekToken().token_type_ != DIGIT) { break; } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 29835f84908194..899863d58aba12 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -7,7 +7,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : add_n_with_kernel args : (Tensor[] inputs) @@ -18,7 +17,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : assert args : (Tensor cond, Tensor[] data, int64_t summarize = -1) @@ -175,16 +173,25 @@ - op : write_to_array args : (Tensor i, Tensor x) output : Tensor[](out) - backward: write_to_array_grad + +- op: dpsgd + args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma = 1.0f, int seed = 0) + output: Tensor(param_out) + infer_meta: + func: DpsgdInferMeta + kernel: + func: dpsgd + data_type: param - op: fused_attention args: (Tensor x, Tensor ln_scale, Tensor ln_bias, Tensor qkv_weight, Tensor qkv_bias, Tensor cache_kv, Tensor src_mask, Tensor out_linear_weight, Tensor out_linear_bias, Tensor ln_scale_2, Tensor ln_bias_2, int num_heads, bool transpose_qkv_wb, bool pre_layer_norm, float epsilon, float attn_dropout_rate, bool is_test, bool attn_dropout_fix_seed, int attn_dropout_seed, str attn_dropout_implementation, float dropout_rate, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon, bool add_residual, int ring_id) output: Tensor(ln_mean), Tensor(ln_var), Tensor(ln_out), Tensor(qkv_out), Tensor(qkv_bias_out), Tensor(transpose_out_2), Tensor(qk_out), Tensor(qktv_out), Tensor(softmax_out), Tensor(attn_dropout_mask_out), Tensor(attn_dropout_out), Tensor(src_mask_out), Tensor(fmha_out), Tensor(out_linear_out), Tensor(dropout_mask_out), Tensor(ln_mean_2), Tensor(ln_var_2), Tensor(bias_dropout_residual_out), Tensor(cache_kv_out), Tensor(out) kernel: func: fused_attention + data_type : x infer_meta: func: FusedAttentionInferMeta - optional: cache_kv, ln_scale, ln_bias, qkv_bias, src_mask, out_linear_bias, ln_scale_2, ln_bias_2, ln_mean_2 + optional: cache_kv, ln_scale, ln_bias, qkv_bias, src_mask, out_linear_bias, ln_scale_2, ln_bias_2, ln_mean_2, ln_var_2, bias_dropout_residual_out, cache_kv_out backward: fused_attention_grad - op: fused_feedforward @@ -192,6 +199,7 @@ output: Tensor(out), Tensor(dropout1_mask), Tensor(dropout2_mask), Tensor(ln1_mean), Tensor(ln1_variance), Tensor(ln2_mean), Tensor(ln2_variance), Tensor(linear1_out), Tensor(ln1_out), Tensor(dropout1_out), Tensor(dropout2_out) kernel: func: fused_feedforward + data_type : x infer_meta: func: FusedFeedForwardInferMeta optional: dropout1_seed, dropout2_seed, linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, ln2_mean, ln2_variance, ln1_mean, ln1_variance, ln1_out diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc index 48e69b689c4981..5452cd6f47f30e 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc @@ -222,5 +222,17 @@ const std::string& OpYamlInfoParser::GetOriginOpName() const { return std::get<4>(op_info_tuple_); } +int OpYamlInfoParser::GetTensorParamIndexByArgsName( + const std::string& args_name) const { + const auto& iter = std::find(kernel_fn_tensor_params_.begin(), + kernel_fn_tensor_params_.end(), + args_name); + if (iter != kernel_fn_tensor_params_.end()) { + return std::distance(kernel_fn_tensor_params_.begin(), iter); + } else { + return -1; + } +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h index 6a4bec08c2b3dc..0a972ced0ef41d 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h @@ -65,6 +65,8 @@ class OpYamlInfoParser { const std::string& GetOriginOpName() const; + int GetTensorParamIndexByArgsName(const std::string& args_name) const; + private: void parse(); inline const std::vector& InputInfo() const { diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index e95ff4b44fcb34..0aa2eaf143f7e9 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" namespace paddle { @@ -22,10 +23,11 @@ const std::unordered_set LegacyOpList = { "pd_op.load_combine", "pd_op.c_concat", "pd_op.c_broadcast_", - "pd_op.fused_bn_add_activation_", - "pd_op.fused_bn_add_activation_grad", "pd_op.c_sync_calc_stream_", "pd_op.c_sync_comm_stream_", + "pd_op.fused_gemm_epilogue", + "pd_op.fused_gemm_epilogue_grad", + "pd_op.dpsgd", "pd_op.send_v2", "pd_op.recv_v2", "pd_op.c_allreduce_sum", diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index 6da122af99716c..1c228e7e850834 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -14,14 +14,13 @@ #pragma once -// #include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/attribute.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace dialect { @@ -128,44 +127,6 @@ static inline pir::Attribute TransToIrAttribute(phi::Scalar scalar, } } -inline DataType VarTypeToDataType( - ::paddle::framework::proto::VarType_Type var_type) { - switch (var_type) { - case paddle::framework::proto::VarType_Type::VarType_Type_BOOL: - return DataType::BOOL; - case paddle::framework::proto::VarType_Type::VarType_Type_INT16: - return DataType::INT16; - case paddle::framework::proto::VarType_Type::VarType_Type_INT32: - return DataType::INT32; - case paddle::framework::proto::VarType_Type::VarType_Type_INT64: - return DataType::INT64; - case paddle::framework::proto::VarType_Type::VarType_Type_FP16: - return DataType::FLOAT16; - case paddle::framework::proto::VarType_Type::VarType_Type_FP32: - return DataType::FLOAT32; - case paddle::framework::proto::VarType_Type::VarType_Type_FP64: - return DataType::FLOAT64; - case paddle::framework::proto::VarType_Type::VarType_Type_SIZE_T: - return DataType::UINT64; - case paddle::framework::proto::VarType_Type::VarType_Type_UINT8: - return DataType::UINT8; - case paddle::framework::proto::VarType_Type::VarType_Type_INT8: - return DataType::INT8; - case paddle::framework::proto::VarType_Type::VarType_Type_BF16: - return DataType::BFLOAT16; - case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX64: - return DataType::COMPLEX64; - case paddle::framework::proto::VarType_Type::VarType_Type_COMPLEX128: - return DataType::COMPLEX128; - case paddle::framework::proto::VarType_Type::VarType_Type_PSTRING: - return DataType::PSTRING; - default: - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported proto::VarType_Type `%s` when casting it into DataType.", - var_type)); - } -} - VariantType GetAttributeData(const pir::Attribute& attr); bool IsLegacyOp(const std::string& name); diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt new file mode 100644 index 00000000000000..c1b524dda69a6a --- /dev/null +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -0,0 +1,65 @@ +file(GLOB DRR_SRCS "*.cc" "api/*.cc") + +set(op_creator_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py +) +set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) +set(op_forward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml +) +set(op_forward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml +) +set(op_backward_yaml_file1 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml +) +set(op_backward_yaml_file2 + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml +) +set(fused_op_forward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml +) +set(fused_op_backward_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml +) + +set(parsed_op_dir + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/generated) + +set(op_yaml_file3 ${parsed_op_dir}/ops.parsed.yaml) +set(op_yaml_file4 ${parsed_op_dir}/ops_backward.parsed.yaml) + +set(op_yaml_files + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${fused_op_forward_yaml_file},${fused_op_backward_yaml_file},${op_yaml_file3},${op_yaml_file4} +) + +set(op_creator_file + ${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/ir_op_factory_generated.cc) +set(op_creator_file_tmp ${op_creator_file}.tmp) + +set(dialect_name pd_op) + +add_custom_command( + OUTPUT ${op_creator_file} + COMMAND + ${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files ${op_yaml_files} + --op_compat_yaml_file ${op_compat_yaml_file} --dialect_name ${dialect_name} + --op_creator_file ${op_creator_file_tmp} + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${op_creator_file_tmp} + ${op_creator_file} + COMMENT "copy_if_different ${op_creator_file}" + DEPENDS ${op_creator_gen_file} + ${op_forward_yaml_file1} + ${op_forward_yaml_file2} + ${op_backward_yaml_file1} + ${op_backward_yaml_file2} + ${op_compat_yaml_file} + ${op_yaml_file3} + ${op_yaml_file4} + pd_op_dialect_op + VERBATIM) + +cc_library( + drr + SRCS ${DRR_SRCS} ${op_creator_file} + DEPS pd_op_dialect pir) diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h new file mode 100644 index 00000000000000..d5f19ff3e6e9be --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -0,0 +1,41 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" + +namespace pir { +namespace drr { + +template +class DrrPatternBase { + public: + virtual ~DrrPatternBase() = default; + + // Define the Drr Pattern. + virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0; + + std::unique_ptr> Build( + pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { + DrrPatternContext drr_context; + this->operator()(&drr_context); + return std::make_unique>( + drr_context, ir_context, benefit); + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/api/drr_pattern_context.cc new file mode 100644 index 00000000000000..5f74b986f1a5e7 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" + +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/pir/core/enforce.h" + +namespace pir { +namespace drr { + +DrrPatternContext::DrrPatternContext() { + source_pattern_graph_ = std::make_shared(); + result_pattern_graph_ = std::make_shared(); +} + +drr::SourcePattern DrrPatternContext::SourcePattern() { + return drr::SourcePattern(this); +} +const Op& DrrPatternContext::SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, source_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +const drr::Tensor& DrrPatternContext::SourceTensorPattern( + const std::string& name) { + return source_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, source_pattern_graph_.get()))); +} + +const Op& DrrPatternContext::ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes) { + owned_ops_.push_back(std::shared_ptr( + new drr::Op(op_type, attributes, result_pattern_graph_.get()))); + return *owned_ops_.back(); +} + +drr::Tensor& DrrPatternContext::ResultTensorPattern(const std::string& name) { + return result_pattern_graph_->AddTensor(std::shared_ptr( + new drr::Tensor(name, result_pattern_graph_.get()))); +} + +std::vector DrrPatternContext::constraints() const { + return constraints_; +} + +// void DrrPatternContext::RequireEqual(const Attribute& first, const Attribute& +// second) { +// auto constrain_fn = [&](const MatchContext& match_context) { +// return match_context.Attr(first.id()) == match_context.Attr(second.id()); +// }; +// constraints_.emplace_back(constrain_fn); +// } + +void DrrPatternContext::RequireEqual(const TensorShape& first, + const TensorShape& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Shape() == + match_context.Tensor(second.tensor_name()).Shape(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireEqual(const TensorDataType& first, + const TensorDataType& second) { + // Note: we capture the datas by value for constrain_fn + // because the datas are destructed before running constrain_fn. + auto constrain_fn = [=](const MatchContext& match_context) { + return match_context.Tensor(first.tensor_name()).Dtype() == + match_context.Tensor(second.tensor_name()).Dtype(); + }; + constraints_.emplace_back(constrain_fn); +} + +void DrrPatternContext::RequireNativeCall( + const std::function& custom_fn) { + constraints_.emplace_back(custom_fn); +} + +void Op::operator()(const Tensor& arg, const Tensor* out) const { + std::vector inputs{&arg}; + std::vector outputs{out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); +} + +void Op::operator()(const std::vector& args, + const std::vector& outputs) const { + pattern_graph_->AddOpCall(std::make_shared(this, args, outputs)); +} + +Tensor& Op::operator()(const Tensor& arg) const { + std::vector inputs{&arg}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const { + std::vector inputs{&arg1, &arg2}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +Tensor& Op::operator()() const { + std::vector inputs{}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + +thread_local int64_t Op::count = 0; +const char* Op::prefix = "@drr_temp@_"; + +const char Tensor::NONE_TENSOR_NAME[] = "__@none_tensor@__"; + +void Tensor::Assign(const Tensor& other) { + dynamic_cast(pattern_graph_)->AssignTensor(*this, other); +} + +void Tensor::operator=(const Tensor& other) const { // NOLINT + // The two tensor must be in the same pattern graph. + IR_ENFORCE(this->pattern_graph_ == other.pattern_graph_); + if (other.name_.find(Op::prefix) == 0 && + name_.find(Op::prefix) == std::string::npos) { + other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_); + } +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h new file mode 100644 index 00000000000000..b4156bd54bf414 --- /dev/null +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -0,0 +1,334 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/match_context.h" + +namespace pir { +namespace drr { + +class Op; +class Tensor; +class OpCall; +class SourcePattern; +class ResultPattern; +class PatternGraph; +class SourcePatternGraph; +class ResultPatternGraph; + +class NormalAttribute { + public: + explicit NormalAttribute(const std::string& name) : attr_name_(name) {} + + const std::string& name() const { return attr_name_; } + + private: + std::string attr_name_; +}; + +using AttrComputeFunc = std::function; + +class ComputeAttribute { + public: + explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func) + : attr_compute_func_(attr_compute_func) {} + + const AttrComputeFunc& attr_compute_func() const { + return attr_compute_func_; + } + + private: + AttrComputeFunc attr_compute_func_; +}; + +using Attribute = std::variant; + +class TensorShape { + public: + explicit TensorShape(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class TensorDataType { + public: + explicit TensorDataType(const std::string& tensor_name) + : tensor_name_(tensor_name) {} + + const std::string& tensor_name() const { return tensor_name_; } + + private: + std::string tensor_name_; +}; + +class Constraint { + public: + explicit Constraint( + const std::function& constrain_fn) + : IsContextMatchConstraint_(constrain_fn) {} + bool operator()(const MatchContext& match_context) const { + return IsContextMatchConstraint_(match_context); + } + + private: + std::function IsContextMatchConstraint_; +}; + +class DrrPatternContext { + public: + DrrPatternContext(); + ~DrrPatternContext() = default; + + drr::SourcePattern SourcePattern(); + + std::shared_ptr source_pattern_graph() const { + return source_pattern_graph_; + } + + std::vector constraints() const; + + std::shared_ptr result_pattern_graph() const { + return result_pattern_graph_; + } + + private: + friend class drr::SourcePattern; + friend class drr::ResultPattern; + + const Op& SourceOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + const drr::Tensor& SourceTensorPattern(const std::string& name); + + const Op& ResultOpPattern( + const std::string& op_type, + const std::unordered_map& attributes = {}); + drr::Tensor& ResultTensorPattern(const std::string& name); + + // void RequireEqual(const Attribute& first, const Attribute& second); + void RequireEqual(const TensorShape& first, const TensorShape& second); + void RequireEqual(const TensorDataType& first, const TensorDataType& second); + void RequireNativeCall( + const std::function& custom_fn); + + std::shared_ptr source_pattern_graph_; + std::vector constraints_; + std::shared_ptr result_pattern_graph_; + + std::vector> owned_ops_; +}; + +class Op { + public: + const std::string& name() const { return op_type_name_; } + + void operator()(const Tensor& arg, const Tensor* out) const; + + Tensor& operator()() const; + + Tensor& operator()(const Tensor& arg) const; + Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; + void operator()(const std::vector& args, + const std::vector& outputs) const; + // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const + // Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const + // Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& + // operator()(const Tensor& arg0, const Tensor& arg1, const Tensor& arg2, + // const Tensor& arg3, const Tensor& arg4) const; + + static const char* prefix; + + private: + friend class DrrPatternContext; + friend class OpCall; + + Op(const std::string& op_type_name, + const std::unordered_map& attributes, + PatternGraph* pattern_graph) + : op_type_name_(op_type_name), + attributes_(attributes), + pattern_graph_(pattern_graph) {} + + const std::unordered_map& attributes() const { + return attributes_; + } + + thread_local static int64_t count; + + std::string op_type_name_; + std::unordered_map attributes_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class Tensor { + public: + static const char NONE_TENSOR_NAME[]; + + const std::string& DebugName() const; + + TensorShape shape() const { return TensorShape(name()); } + + TensorDataType dtype() const { return TensorDataType(name()); } + + bool is_none() const { return name_ == NONE_TENSOR_NAME; } + + void Assign(const Tensor& other); + + void operator=(const Tensor& other) const; // NOLINT + + const std::string& name() const { return name_; } + + void set_name(const std::string& name) { name_ = name; } + + OpCall* producer() const { return producer_; } + + void set_producer(OpCall* producer) { producer_ = producer; } + + const std::vector& consumers() const { return consumers_; } + + void set_consumables(const std::vector& consumers) { + consumers_ = consumers; + } + + void AddConsumer(const OpCall* consumer) { consumers_.push_back(consumer); } + + private: + friend class DrrPatternContext; + friend class Op; + + Tensor(const std::string& name, PatternGraph* pattern_graph) + : name_(name), pattern_graph_(pattern_graph) {} + + std::string name_; + OpCall* producer_{nullptr}; + std::vector consumers_; + PatternGraph* pattern_graph_{nullptr}; +}; + +class OpCall { + public: + OpCall(const Op* op, + const std::vector& inputs, + const std::vector& outputs) + : op_name_(op->op_type_name_), + inputs_(inputs), + outputs_(outputs), + attributes_(op->attributes_) {} + + const std::string& name() const { return op_name_; } + + const std::vector& inputs() const { return inputs_; } + + const std::vector& outputs() const { return outputs_; } + + const std::unordered_map& attributes() const { + return attributes_; + } + + private: + std::string op_name_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map attributes_; +}; + +class ResultPattern { + public: + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->ResultOpPattern(op_type, attributes); + } + + drr::Tensor& Tensor(const std::string& name) { + return ctx_->ResultTensorPattern(name); + } + + // Represent the input tensor which is none. + // Example: + // instance_norm has follow input tensor : (x, scale, bias), scale and + // bias are optional(means it may be none). + // When scale is onoe, we can write a instance_norm op in drr as follow: + // res.Op("instance_norm")(res.Tensor("x"), res.NoneTensor, + // res.Tensor("bias")); + drr::Tensor& NoneTensor() { + return ctx_->ResultTensorPattern(Tensor::NONE_TENSOR_NAME); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + Attribute Attr(const AttrComputeFunc& attr_compute_func) const { + return ComputeAttribute(attr_compute_func); + } + + private: + friend class SourcePattern; + + explicit ResultPattern(DrrPatternContext* ctx) : ctx_(ctx) {} + + DrrPatternContext* ctx_{nullptr}; +}; + +class SourcePattern { + public: + drr::ResultPattern ResultPattern() const { return drr::ResultPattern(ctx_); } + + const drr::Op& Op( + const std::string& op_type, + const std::unordered_map& attributes = {}) { + return ctx_->SourceOpPattern(op_type, attributes); + } + + const drr::Tensor& Tensor(const std::string& name) { + return ctx_->SourceTensorPattern(name); + } + + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + + void RequireEqual(const TensorShape& first, const TensorShape& second) { + ctx_->RequireEqual(first, second); + } + void RequireEqual(const TensorDataType& first, const TensorDataType& second) { + ctx_->RequireEqual(first, second); + } + + void RequireNativeCall( + const std::function& custom_fn) { + ctx_->RequireNativeCall(custom_fn); + } + + private: + friend class DrrPatternContext; + explicit SourcePattern(DrrPatternContext* ctx) : ctx_(ctx) {} + DrrPatternContext* ctx_{nullptr}; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/api/match_context.cc new file mode 100644 index 00000000000000..35b28db13254ed --- /dev/null +++ b/paddle/fluid/pir/drr/api/match_context.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/match_context.h" + +#include + +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" + +namespace pir { +namespace drr { + +MatchContext::MatchContext(std::shared_ptr impl) + : impl_(impl) {} + +const TensorInterface& MatchContext::Tensor( + const std::string& tensor_name) const { + return impl_->Tensor(tensor_name); +} + +template +T MatchContext::Attr(const std::string& attr_name) const { + return impl_->Attr(attr_name); +} + +template bool MatchContext::Attr(const std::string&) const; +template int32_t MatchContext::Attr(const std::string&) const; +template int64_t MatchContext::Attr(const std::string&) const; +template float MatchContext::Attr(const std::string&) const; +template std::string MatchContext::Attr(const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; +template std::vector MatchContext::Attr>( + const std::string&) const; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/api/match_context.h new file mode 100644 index 00000000000000..a1699ccb5bddf6 --- /dev/null +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -0,0 +1,43 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_operation.h" + +namespace pir { +namespace drr { + +class TensorInterface; +class MatchContextImpl; + +class MatchContext final { + public: + MatchContext(std::shared_ptr impl); + + const TensorInterface& Tensor(const std::string& tensor_name) const; + + template + T Attr(const std::string& attr_name) const; + + private: + std::shared_ptr impl_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc new file mode 100644 index 00000000000000..1b81b3a5672117 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/ir_value.h" + +namespace pir { +namespace drr { + +bool ShapeInterface::operator==(const ShapeInterface& other) const { + return *shape_ == *other.shape_; +} + +int ShapeInterface::size() const { return shape_->size(); } + +int64_t ShapeInterface::at(int idx) const { return shape_->at(idx); } + +bool DtypeInterface::operator==(const DtypeInterface& other) const { + return *dtype_ == *other.dtype_; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h new file mode 100644 index 00000000000000..7629857591bf33 --- /dev/null +++ b/paddle/fluid/pir/drr/api/tensor_interface.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace pir { +namespace drr { + +class IrValue; +class IrShape; +class IrDtype; + +class ShapeInterface final { + public: + bool operator==(const ShapeInterface& other) const; + + int size() const; + + int64_t at(int idx) const; + + private: + explicit ShapeInterface(const IrShape* shape) : shape_(shape) {} + + friend class IrValue; + + const IrShape* shape_; +}; + +class DtypeInterface final { + public: + bool operator==(const DtypeInterface& other) const; + + private: + explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {} + + friend class IrValue; + + const IrDtype* dtype_; +}; + +class TensorInterface { + public: + virtual ShapeInterface Shape() const = 0; + virtual DtypeInterface Dtype() const = 0; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h new file mode 100644 index 00000000000000..fb989fe063b771 --- /dev/null +++ b/paddle/fluid/pir/drr/attr_type_uilts.h @@ -0,0 +1,116 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +template +struct CppTypeToIrAttribute; + +#define PD_SPECIALIZE_CppTypeToIrAttribute(cpp_type, ir_attr_type) \ + template <> \ + struct CppTypeToIrAttribute< \ + std::remove_const_t>> { \ + using type = ir_attr_type; \ + }; + +PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute); +PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, + paddle::dialect::DataTypeAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, + paddle::dialect::IntArrayAttribute); + +template +struct IrAttrbuteCreator { + typename CppTypeToIrAttribute::type operator()(T obj) const { + return CppTypeToIrAttribute::type::template get( + pir::IrContext::Instance(), obj); + } +}; + +template <> +struct IrAttrbuteCreator> { + pir::ArrayAttribute operator()(std::vector obj) const { + std::vector attr_vec; + attr_vec.reserve(obj.size()); + for (int32_t x : obj) { + attr_vec.push_back(Int32Attribute::get(pir::IrContext::Instance(), x)); + } + return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec); + } +}; + +template +struct IrAttrTypeCast { + static T To(const pir::Attribute& attr) { + return attr.dyn_cast::type>().data(); + } +}; + +template <> +struct IrAttrTypeCast { + static std::string To(const pir::Attribute& attr) { + return attr.dyn_cast::type>() + .AsString(); + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back(array_attr.at(i).dyn_cast().data()); + } + return result; + } +}; + +template <> +struct IrAttrTypeCast> { + static std::vector To(const pir::Attribute& attr) { + std::vector result; + if (attr.dyn_cast()) { + auto array_attr = attr.dyn_cast(); + for (size_t i = 0; i < array_attr.size(); i++) { + result.push_back( + array_attr.at(i).dyn_cast().data()); + } + } else if (attr.dyn_cast()) { + result = + attr.dyn_cast().data().GetData(); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "Dynamic cast failed for IR attribute vector")); + } + return result; + } +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/drr_rewrite_pattern.h new file mode 100644 index 00000000000000..c17feb0eaad052 --- /dev/null +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.h @@ -0,0 +1,568 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_operation_factory.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/type_name.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +template +class DrrRewritePattern : public pir::RewritePattern { + public: + explicit DrrRewritePattern(const DrrPatternContext& drr_context, + pir::IrContext* context, + pir::PatternBenefit benefit = 1) + : pir::RewritePattern( + drr_context.source_pattern_graph()->AnchorNode()->name(), + benefit, + context, + {}), + source_pattern_graph_(drr_context.source_pattern_graph()), + constraints_(drr_context.constraints()), + result_pattern_graph_(drr_context.result_pattern_graph()) { + IR_ENFORCE(!source_pattern_graph_->owned_op_call().empty(), + "source_pattern_graph is empty, please check the drr pattern " + "define code."); + } + + bool MatchAndRewrite(pir::Operation* op, + PatternRewriter& rewriter) const override { // NOLINT + std::shared_ptr src_match_ctx = + std::make_shared(); + if (PatternGraphMatch(op, src_match_ctx.get())) { + VLOG(4) << "DRR pattern (" << pir::get_type_name() + << ") is matched in program."; + PatternGraphRewrite(*src_match_ctx, rewriter); + return true; + } + return false; + } + + private: + bool PatternGraphMatch(pir::Operation* op, + MatchContextImpl* source_pattern_match_ctx) const { + VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; + const OpCall* anchor = source_pattern_graph_->AnchorNode(); + std::unordered_map> + bind_map = + FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); + if (bind_map.empty()) { + return false; + } + std::vector drr_output_sequence; + std::vector ir_output_sequence; + std::unordered_map output_op_map; + for (auto pair : bind_map) { + drr_output_sequence.push_back(pair.first); + } + // using dfs to obtain the arrangement of all candidate ir ops + auto permute = [&](auto&& permute, size_t index) -> bool { + if (index == drr_output_sequence.size()) { + // avoiding duplicate binding of ir op + std::unordered_set ir_output_set; + for (Operation* op : ir_output_sequence) { + auto pr = ir_output_set.insert(op); + if (pr.second == false) { + return false; + } + } + // new match_ctx + std::shared_ptr match_ctx = + std::make_shared(); + std::transform(drr_output_sequence.begin(), + drr_output_sequence.end(), + ir_output_sequence.begin(), + std::inserter(output_op_map, output_op_map.end()), + [](const OpCall* drr_op, Operation* ir_op) { + return std::make_pair(drr_op, ir_op); + }); + if (MatchFromOutputToInput( + output_op_map, *(source_pattern_graph_.get()), match_ctx)) { + *source_pattern_match_ctx = *match_ctx; + return true; + } + return false; + } + for (auto* ir_op : bind_map[drr_output_sequence[index]]) { + ir_output_sequence.push_back(ir_op); + if (permute(permute, index + 1)) { + return true; + } + ir_output_sequence.pop_back(); + } + return false; + }; + + return permute(permute, 0); + } + + std::unordered_map> + FindCandidateIrOutputOp( + pir::Operation* op, + const OpCall* anchor, + const SourcePatternGraph& source_pattern_graph) const { + // get source pattern output op + std::unordered_set drr_output_op_set = + source_pattern_graph.OutputNodes(); + std::unordered_map> + output_op_bind_map{{anchor, {op}}}; + if (drr_output_op_set.size() == 1) { + return output_op_bind_map; + } + std::unordered_set drr_visited_ops{anchor}; + DfsVisitor( + anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map); + if (output_op_bind_map.size() != drr_output_op_set.size()) { + return {}; + } + return output_op_bind_map; + } + + void DfsVisitor( + const OpCall* drr_op, + pir::Operation* ir_op, + const std::unordered_set& drr_output_op_set, + std::unordered_set* drr_visited_ops, + std::unordered_map>* + output_op_bind_map) const { + VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")" + << "ir op(" << ir_op->name() << ")"; + if (drr_op->name() != ir_op->name()) { + return; + } + // check op input's size + const auto& drr_op_input_tensors = drr_op->inputs(); + auto ir_op_input_value_size = ir_op->num_operands(); + if (drr_op_input_tensors.size() != ir_op_input_value_size) { + return; + } + // check op output's size + const auto& drr_op_output_tensors = drr_op->outputs(); + auto ir_op_output_value_size = ir_op->num_results(); + if (drr_op_output_tensors.size() != ir_op_output_value_size) { + return; + } + // check producer op + for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) { + // case 1: drr_op_input_tensor is the input tensor of source pattern + if (drr_op_input_tensors[i]->producer() == nullptr) { + // dfs source pattern input tensor other child op + auto ir_input_tensor = ir_op->operand(i).source(); + for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) { + if (drr_visited_ops->count(drr_bro_op)) { + continue; + } + for (auto it = ir_input_tensor.use_begin(); + it != ir_input_tensor.use_end(); + ++it) { + auto* ir_bro_op = it.owner(); + if (drr_bro_op->name() == ir_bro_op->name()) { + drr_visited_ops->insert(drr_bro_op); + DfsVisitor(drr_bro_op, + ir_bro_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_bro_op); + } + } + } + continue; + } + // case 2: have producer op + const auto& drr_producer_op = drr_op_input_tensors[i]->producer(); + if (drr_visited_ops->count(drr_producer_op)) { + continue; + } + auto ir_operand_value = ir_op->operand(i).source(); + if (drr_op_input_tensors[i]->consumers().size() != + ir_operand_value.use_count()) { + return; + } + auto* ir_producer_op = ir_operand_value.dyn_cast().owner(); + drr_visited_ops->insert(drr_producer_op); + DfsVisitor(drr_producer_op, + ir_producer_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_producer_op); + } + if (drr_output_op_set.count(drr_op)) { + (*output_op_bind_map)[drr_op].insert(ir_op); + return; + } + // check child ops + for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) { + const auto& drr_child_ops = drr_op_output_tensors[i]->consumers(); + auto ir_output_value = ir_op->result(i); + if (drr_child_ops.size() != ir_output_value.use_count()) { + return; + } + for (auto* drr_child_op : drr_child_ops) { + for (auto it = ir_output_value.use_begin(); + it != ir_output_value.use_end(); + ++it) { + auto* ir_child_op = it.owner(); + if (drr_child_op->name() == ir_child_op->name()) { + if (drr_visited_ops->count(drr_child_op)) { + continue; + } + drr_visited_ops->insert(drr_child_op); + DfsVisitor(drr_child_op, + ir_child_op, + drr_output_op_set, + drr_visited_ops, + output_op_bind_map); + drr_visited_ops->erase(drr_child_op); + } + } + } + } // check child ops + return; + } + + bool MatchFromOutputToInput( + std::unordered_map output_op_map, + const SourcePatternGraph& source_pattern_graph, + const std::shared_ptr& source_pattern_match_ctx) const { + VLOG(6) << "MatchFromOutputToInput Start"; + std::unordered_set drr_visited; + std::unordered_set ir_visited; + std::queue drr_q; + std::queue ir_q; + bool matched = true; + size_t step = 0; + for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { + VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" + << it->second << ") in source_pattern_graph "; + drr_q.push(it->first); + drr_visited.insert(it->first); + ir_q.push(it->second); + ir_visited.insert(it->second); + } + while (!drr_q.empty()) { + if (!matched) break; + auto* drr_node = drr_q.front(); + auto* ir_node = ir_q.front(); + drr_q.pop(); + ir_q.pop(); + if (drr_node->name() != ir_node->name()) { + matched = false; + break; + } + const auto& drr_input_tensors = drr_node->inputs(); + auto ir_input_value_size = ir_node->num_operands(); + if (drr_input_tensors.size() != ir_input_value_size) { + matched = false; + break; + } + if (drr_node->outputs().size() != ir_node->num_results()) { + matched = false; + break; + } + source_pattern_match_ctx->BindIrOperation( + drr_node, std::make_shared(ir_node)); + // binding input_tensor of current_op + for (size_t i = 0; i < drr_input_tensors.size(); ++i) { + source_pattern_match_ctx->BindIrValue( + drr_input_tensors[i]->name(), + std::make_shared(ir_node->operand(i).source())); + auto* drr_producer_op = drr_input_tensors[i]->producer(); + if (drr_producer_op == nullptr) { + continue; + } + auto* ir_producer_op = + ir_node->operand(i).source().dyn_cast().owner(); + if (drr_input_tensors[i]->consumers().size() != + ir_node->operand(i).source().use_count()) { + matched = false; + break; + } + // bfs producer_op of current_op + if (!drr_visited.count(drr_producer_op)) { + drr_q.push(drr_producer_op); + ir_q.push(ir_producer_op); + drr_visited.insert(drr_producer_op); + ir_visited.insert(ir_producer_op); + } + } + // binding output tensor of current_op + auto drr_op_output_tensor = drr_node->outputs(); + for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { + source_pattern_match_ctx->BindIrValue( + drr_op_output_tensor[j]->name(), + std::make_shared(ir_node->result(j))); + } + ++step; + } + + if (matched) { + IR_ENFORCE(step == source_pattern_graph.CountOfOpCalls()); + } else { + return matched; + } + + MatchContext match_context{source_pattern_match_ctx}; + for (const auto& constraint : constraints_) { + matched = constraint(match_context); + if (!matched) break; + } + + return matched; + } + + void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + VLOG(6) << "Create Operations in result_pattern_graph"; + MatchContextImpl res_match_ctx = CreateOperations(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + VLOG(6) << "Process Assign Tensor"; + RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx); + VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values " + "in result_pattern_graph"; + ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); + VLOG(6) << "Delete Operations in source_pattern_graph"; + DeleteSourcePatternOp(*source_pattern_graph_, + *result_pattern_graph_, + source_pattern_match_ctx, + rewriter); + } + + private: + MatchContextImpl CreateOperations( + const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + MatchContextImpl res_match_ctx; + // add input tensors info for res_match_ctx + for (const auto& in_tensor : result_pattern_graph.input_tensors()) { + IR_ENFORCE(result_pattern_graph.id2owend_tensor().count(in_tensor), + "Drr input tensor [%s] must exists in result pattern graph.", + in_tensor); + if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + res_match_ctx.BindIrValue( + in_tensor, + std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + } + } + + if (result_pattern_graph.CountOfOpCalls() == 1) { + CreateOperation(*result_pattern_graph.owned_op_call()[0], + src_match_ctx, + rewriter, + &res_match_ctx); + return res_match_ctx; + } + + std::vector> temp_program; + std::unordered_map op_2_temp_program_index; + for (Operation* op : *rewriter.block()) { + op_2_temp_program_index[op] = temp_program.size(); + temp_program.push_back({op}); + } + + // topo order visit result_pattern_graph + GraphTopo graph_topo_visit(&result_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { + // set insert point + size_t max_input_op_index = 0; + Operation* max_index_op = nullptr; + for (const Tensor* input : op_call.inputs()) { + if (input->is_none()) { + continue; + } + Value ir_val = res_match_ctx.GetIrValue(input->name()).get(); + if (ir_val) { + Operation* ir_input_op = ir_val.dyn_cast().owner(); + if (max_input_op_index < op_2_temp_program_index[ir_input_op]) { + max_input_op_index = op_2_temp_program_index[ir_input_op]; + max_index_op = ir_input_op; + } else if (max_input_op_index == + op_2_temp_program_index[ir_input_op]) { + const auto& ops_vec = temp_program[max_input_op_index]; + for (auto it = ops_vec.rbegin(); it != ops_vec.rend(); it++) { + if (*it == max_index_op) { + break; + } else if (*it == ir_input_op) { + max_index_op = ir_input_op; + break; + } else { + // do nothing + } + } + } else { + // do nothing + } + } + } + if (max_input_op_index == 0UL) { + VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; + Operation* source_patter_first_op = + src_match_ctx + .Operation(source_pattern_graph.owned_op_call()[0].get()) + .get(); + max_input_op_index = op_2_temp_program_index[source_patter_first_op]; + rewriter.SetInsertionPoint(source_patter_first_op); + } else { + rewriter.SetInsertionPointAfter(max_index_op); + } + + Operation* new_op = + CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx); + op_2_temp_program_index[new_op] = max_input_op_index + 1; + temp_program[max_input_op_index + 1].push_back(new_op); + }); + + return res_match_ctx; + } + + void RebindIrTensorForAssignTensor( + const ResultPatternGraph& result_pattern_graph, + MatchContextImpl* res_match_ctx) const { + const auto& tensor_assign_map = result_pattern_graph.tensor_assign_map(); + for (const auto& kv : tensor_assign_map) { + const auto& src_tensor_name = kv.first; + const auto& dst_tensor_name = kv.second; + res_match_ctx->BindIrValue( + src_tensor_name, + std::make_shared( + res_match_ctx->GetIrValue(dst_tensor_name))); + } + } + + void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx, + const MatchContextImpl& res_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + for (const auto& output_name : result_pattern_graph_->output_tensors()) { + if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); + const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); + rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + } else { + LOG(WARNING) << "The output tensor (" << output_name + << ") in the result_pattern_graph is not the tensor" + " in source_pattern_graph."; + } + } + } + + void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph, + const ResultPatternGraph& result_pattern_graph, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter) const { // NOLINT + std::vector topo_order_ops; + GraphTopo graph_topo_visit(&source_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder( + [&topo_order_ops](const OpCall& op_call) { + topo_order_ops.push_back(&op_call); + }); + + // Filter the operations which are replaced by result pattern + // 1. Filter operations by forward walk + std::unordered_set forward_visited_tensor_set( + result_pattern_graph.input_tensors()); + std::unordered_set forward_deleted_ops; + std::for_each(topo_order_ops.begin(), + topo_order_ops.end(), + [&forward_deleted_ops, + &forward_visited_tensor_set](const OpCall* op_call) { + if (op_call->inputs().empty()) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + } + for (const auto* input : op_call->inputs()) { + if (forward_visited_tensor_set.count(input->name())) { + forward_deleted_ops.insert(op_call); + for (const auto* output : op_call->outputs()) { + forward_visited_tensor_set.insert(output->name()); + } + break; + } + } + }); + // 2. Filter operations by backward walk and merge the forward result + std::unordered_set backward_visited_tensor_set( + result_pattern_graph.output_tensors()); + std::vector deleted_ops; + std::unordered_set deleted_ops_set; + std::for_each(topo_order_ops.rbegin(), + topo_order_ops.rend(), + [&deleted_ops, + &deleted_ops_set, + &backward_visited_tensor_set, + &forward_deleted_ops](const OpCall* op_call) { + bool all_comsumer_deleted = true; + bool from_backward_visited_tensor = false; + for (const auto* output : op_call->outputs()) { + if (backward_visited_tensor_set.count(output->name())) { + from_backward_visited_tensor = true; + } else if (output->consumers().empty()) { + continue; + } else { + all_comsumer_deleted = false; + } + } + if (all_comsumer_deleted && from_backward_visited_tensor && + forward_deleted_ops.count(op_call)) { + deleted_ops_set.insert(op_call); + deleted_ops.push_back(op_call); + for (const auto* input : op_call->inputs()) { + backward_visited_tensor_set.insert(input->name()); + } + } + }); + + // Delete Operation with topo order from output tensors. + for (const auto* op_call : deleted_ops) { + IR_ENFORCE(src_match_ctx.operation_map().count(op_call), + "Drr OpCall [%s] must exists in match context.", + op_call->name()); + auto* op = src_match_ctx.operation_map().at(op_call)->get(); + VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op + << ") in source_pattern_graph "; + rewriter.EraseOp(op); + } + } + + private: + const std::shared_ptr source_pattern_graph_; + const std::vector constraints_; + const std::shared_ptr result_pattern_graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/ir_operation.h new file mode 100644 index 00000000000000..2764bc92454170 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/operation.h" + +namespace pir { +namespace drr { + +class IrOperation { + public: + explicit IrOperation(pir::Operation* op) : op_(op) {} + + pir::Operation* get() const { return op_; } + + private: + pir::Operation* op_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc new file mode 100644 index 00000000000000..5355a8977e8c53 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -0,0 +1,166 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/ir_operation_factory.h" + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +void OperationFactory::RegisterManualOpCreator() { + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + attrs); + }); + RegisterOperationCreator( + "pd_op.fused_gemm_epilogue_grad", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build( + inputs[0].dyn_cast(), + inputs[1].dyn_cast(), + inputs[2].dyn_cast(), + inputs[3].dyn_cast(), + attrs); + }); + RegisterOperationCreator("builtin.combine", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + return rewriter.Build(inputs); + }); +} + +static pir::Attribute CreateIrAttribute(const std::any& obj) { + if (obj.type() == typeid(bool)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int32_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int64_t)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(float)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::string)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(const char*)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(phi::DataType)) { + return IrAttrbuteCreator()( + std::any_cast(obj)); + } else if (obj.type() == typeid(phi::Place)) { + return IrAttrbuteCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttrbuteCreator>()( + std::any_cast>(obj)); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Type error. CreateIrAttribute for type(%s) " + "is unimplemented CreateInCurrently.", + obj.type().name())); + } +} + +pir::AttributeMap CreateAttributeMap(const OpCall& op_call, + const MatchContextImpl& src_match_ctx) { + pir::AttributeMap attr_map; + for (const auto& kv : op_call.attributes()) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name()); + } + if constexpr (std::is_same_v, + ComputeAttribute>) { + MatchContext ctx(std::make_shared(src_match_ctx)); + attr_map[kv.first] = + CreateIrAttribute(arg.attr_compute_func()(ctx)); + } + }, + kv.second); + } + return attr_map; +} + +Value GetIrValueByDrrTensor(const Tensor& tensor, + const MatchContextImpl& res_match_ctx) { + if (tensor.is_none()) { + return Value{}; + } + return res_match_ctx.GetIrValue(tensor.name()).get(); +} + +std::vector GetIrValuesByDrrTensors( + const std::vector& tensors, + const MatchContextImpl& res_match_ctx) { + std::vector ir_values; + ir_values.reserve(tensors.size()); + for (const auto* tensor : tensors) { + ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx)); + } + return ir_values; +} + +void BindIrOutputs(const OpCall& op_call, + pir::Operation* op, + MatchContextImpl* match_ctx) { + for (size_t i = 0; i < op_call.outputs().size(); ++i) { + std::shared_ptr ir_value = nullptr; + if (op->result(i)) { + ir_value = std::make_shared(op->result(i)); + } + match_ctx->BindIrValue(op_call.outputs()[i]->name(), ir_value); + } +} + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx) { + VLOG(6) << "Drr create [" << op_call.name() << "] op..."; + const auto& inputs = op_call.inputs(); + std::vector ir_values = + GetIrValuesByDrrTensors(inputs, *res_match_ctx); + pir::Operation* op = OperationFactory::Instance().CreateOperation( + op_call.name(), + ir_values, + CreateAttributeMap(op_call, src_match_ctx), + rewriter); + BindIrOutputs(op_call, op, res_match_ctx); + VLOG(6) << "Drr create [" << op_call.name() << "] op done."; + return op; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h new file mode 100644 index 00000000000000..b38b5cd6a12b32 --- /dev/null +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" + +namespace pir { +namespace drr { + +class OperationFactory { + public: + static OperationFactory& Instance() { + static OperationFactory operation_factory; + return operation_factory; + } + + using operation_create_fn = + std::function&, + const pir::AttributeMap&, + pir::PatternRewriter&)>; + + void RegisterOperationCreator(const std::string& op_name, + const operation_create_fn& create_fn) { + op_creator_map.emplace(op_name, create_fn); + } + + pir::Operation* CreateOperation( + const std::string& op_name, + const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) const { // NOLINT + auto iter = op_creator_map.find(op_name); + IR_ENFORCE(iter != op_creator_map.end(), + "The create function for op: (%s) is not found.", + op_name); + return iter->second(inputs, attrs, rewriter); + } + + private: + OperationFactory() { + RegisterGeneratedOpCreator(); + RegisterManualOpCreator(); + } + + void RegisterManualOpCreator(); + void RegisterGeneratedOpCreator(); + + std::unordered_map op_creator_map; +}; + +pir::Operation* CreateOperation(const OpCall& op_call, + const MatchContextImpl& src_match_ctx, + pir::PatternRewriter& rewriter, // NOLINT + MatchContextImpl* res_match_ctx); + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/ir_value.h new file mode 100644 index 00000000000000..907df9dfd24ebc --- /dev/null +++ b/paddle/fluid/pir/drr/ir_value.h @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/value.h" + +namespace pir { +namespace drr { + +class IrShape { + public: + explicit IrShape(const phi::DDim& dims) : dims_(dims) {} + + bool operator==(const IrShape& other) const { return dims_ == other.dims_; } + + int size() const { return dims_.size(); } + + int64_t at(int idx) const { return dims_.at(idx); } + + private: + const phi::DDim dims_; +}; + +class IrDtype { + public: + explicit IrDtype(pir::Type dtype) : dtype_(dtype) {} + + bool operator==(IrDtype other) const { return dtype_ == other.dtype_; } + + private: + const pir::Type dtype_; +}; + +class IrValue : public TensorInterface { + public: + explicit IrValue(const pir::Value& value) + : value_(value), + shape_((value && value.type() && + value.type().dyn_cast()) + ? value.type() + .dyn_cast() + .dims() + : phi::DDim{}), + dtype_((value && value.type() && + value.type().dyn_cast()) + ? value.type() + .dyn_cast() + .dtype() + : pir::Type{}) {} + + ShapeInterface Shape() const override { return ShapeInterface(&shape_); } + DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); } + + const Value& get() const { return value_; } + + private: + const Value value_; + const IrShape shape_; + const IrDtype dtype_; +}; + +class IrAttr; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h new file mode 100644 index 00000000000000..a04efbbfaf444b --- /dev/null +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -0,0 +1,124 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/tensor_interface.h" +#include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/ir_value.h" +#include "paddle/pir/core/builtin_attribute.h" + +namespace pir { +namespace drr { + +class MatchContextImpl final { + public: + MatchContextImpl() = default; + ~MatchContextImpl() = default; + + const TensorInterface& Tensor(const std::string& tensor_name) const { + IR_ENFORCE(tensor_map_.count(tensor_name), + "Drr tensor [%s] must exists in pattern graph.", + tensor_name); + return *tensor_map_.at(tensor_name); + } + + const IrOperation& Operation(const OpCall* op_call) const { + IR_ENFORCE(operation_map_.count(op_call), + "Drr operation [%s] must exists in pattern graph.", + op_call->name()); + return *operation_map_.at(op_call); + } + + template + T Attr(const std::string& attr_name) const { + return IrAttrTypeCast::To(GetIrAttr(attr_name)); + } + + const IrValue& GetIrValue(const std::string& tensor_name) const { + auto iter = tensor_map_.find(tensor_name); + PADDLE_ENFORCE_NE( + iter, + tensor_map_.end(), + phi::errors::OutOfRange( + "the drr tensor(%s) is not found in the map to ir value.", + tensor_name)); + return *iter->second; + } + + pir::Attribute GetIrAttr(const std::string& attr_name) const { + auto iter = attr_map_.find(attr_name); + PADDLE_ENFORCE_NE( + iter, + attr_map_.end(), + phi::errors::OutOfRange( + "the drr attr(%s) is not found in the map to ir attribute.", + attr_name)); + return iter->second; + } + + const std::unordered_map>& + operation_map() const { + return operation_map_; + } + + const std::unordered_map& attr_map() const { + return attr_map_; + } + + const std::unordered_map>& tensor_map() + const { + return tensor_map_; + } + + void BindIrValue(const std::string& value_name, + const std::shared_ptr& value) { + tensor_map_.emplace(value_name, value); + } + + void BindIrOperation(const OpCall* op_call, + const std::shared_ptr& op) { + operation_map_.emplace(op_call, op); + const auto& attrs = op_call->attributes(); + for (const auto& kv : attrs) { + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + } + }, + kv.second); + } + } + + private: + void BindIrAttr(const std::string& attr_name, pir::Attribute attr) { + attr_map_.emplace(attr_name, attr); + } + + std::unordered_map> tensor_map_; + std::unordered_map> + operation_map_; + std::unordered_map attr_map_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc new file mode 100644 index 00000000000000..0b63f398a790bd --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -0,0 +1,223 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/drr/pattern_graph.h" + +#include + +#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/pir/core/enforce.h" + +namespace pir { +namespace drr { + +const drr::OpCall &PatternGraph::AddOpCall( + const std::shared_ptr &op_call) { + owned_op_call_.push_back(op_call); + for (const auto *input : op_call->inputs()) { + const auto &tensor_name = input->name(); + IR_ENFORCE(id2owned_tensor_.count(tensor_name), + "intput tensor [%s] not exist.", + tensor_name); + id2owned_tensor_.at(tensor_name)->AddConsumer(op_call.get()); + + if (input->producer() == nullptr) { + input_tensors_.insert(tensor_name); + } + if (output_tensors_.find(tensor_name) != output_tensors_.end()) { + output_tensors_.erase(tensor_name); + } + } + for (auto &output : op_call->outputs()) { + const auto &out_tensor_name = output->name(); + IR_ENFORCE(id2owned_tensor_.count(out_tensor_name)); + id2owned_tensor_[output->name()]->set_producer(op_call.get()); + } + return *owned_op_call_.back(); +} + +drr::Tensor &PatternGraph::AddTensor( + const std::shared_ptr &tensor) { + if (id2owned_tensor_.find(tensor->name()) == id2owned_tensor_.end()) { + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + } + return *id2owned_tensor_[tensor->name()]; +} + +drr::Tensor &PatternGraph::AddTmpTensor( + const std::shared_ptr &tensor) { + IR_ENFORCE(id2owned_tensor_.count(tensor->name()) == 0); + id2owned_tensor_[tensor->name()] = tensor; + output_tensors_.insert(tensor->name()); + return *id2owned_tensor_[tensor->name()]; +} + +void PatternGraph::UpdateTmpTensor(const std::string &tmp_tensor_name, + const std::string &new_tensor_name) { + if (input_tensors_.count(tmp_tensor_name)) { + input_tensors_.erase(tmp_tensor_name); + input_tensors_.insert(new_tensor_name); + } + + output_tensors_.erase(new_tensor_name); + if (output_tensors_.count(tmp_tensor_name)) { + output_tensors_.erase(tmp_tensor_name); + output_tensors_.insert(new_tensor_name); + } + + auto tmp_tensor = id2owned_tensor_[tmp_tensor_name]; + id2owned_tensor_.erase(tmp_tensor_name); + tmp_tensor->set_name(new_tensor_name); + id2owned_tensor_[new_tensor_name] = tmp_tensor; +} + +size_t PatternGraph::CountOfOpCalls() const { return owned_op_call_.size(); } + +OpCall *SourcePatternGraph::AnchorNode() const { + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + return output_op_candidate; + } + IR_THROW("Unable to find a valid anchor"); +} + +std::unordered_set SourcePatternGraph::OutputNodes() const { + std::unordered_set output_op_set; + for (const auto &output_tensor : output_tensors_) { + OpCall *output_op_candidate = + id2owned_tensor_.at(output_tensor)->producer(); + if (std::all_of(output_op_candidate->outputs().begin(), + output_op_candidate->outputs().end(), + [this](const Tensor *output) -> bool { + return this->output_tensors().count(output->name()); + })) + output_op_set.insert(output_op_candidate); + } + return output_op_set; +} + +void ResultPatternGraph::AssignTensor(const Tensor &from, const Tensor &to) { + if (to.producer() == nullptr) { + input_tensors_.insert(to.name()); + } + output_tensors_.erase(to.name()); + IR_ENFORCE(output_tensors_.count(from.name()) == 1, + "The Tensor (%s) which be assigned must be the output of result " + "pattern graph.", + from.name()); + tensor_assign_map_[from.name()] = to.name(); +} + +void GraphTopo::WalkGraphNodesTopoOrder( + const std::function &VisitNode) const { + // graph data + const std::unordered_set &inputs_tensor = + graph_->input_tensors(); + const std::unordered_map> + &id2owned_tensor = graph_->id2owend_tensor(); + const std::vector> &owend_opcall = + graph_->owned_op_call(); + + std::queue opcall_queue; + std::unordered_map> + opcall_dependent; + + // init opcall_dependent + for (const std::shared_ptr &opcall_sptr : owend_opcall) { + if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty + opcall_queue.push(opcall_sptr.get()); + } else { + for (const auto &pre_depd_tensor : opcall_sptr.get()->inputs()) { + opcall_dependent[opcall_sptr.get()].insert(pre_depd_tensor->name()); + } + } + } + + // init queue + for (const auto &tensor_name : inputs_tensor) { + IR_ENFORCE(id2owned_tensor.count(tensor_name), + "Drr input tensor [%s] must exists in pattern graph.", + tensor_name); + for (const auto &tensor_comsumer : + id2owned_tensor.at(tensor_name).get()->consumers()) { + opcall_dependent[tensor_comsumer].erase(tensor_name); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + + while (!opcall_queue.empty()) { + const OpCall *opcall = opcall_queue.front(); + opcall_queue.pop(); + VisitNode(*opcall); + + // update opcall_dependent + for (const auto &output_tensor : opcall->outputs()) { + for (const auto &tensor_comsumer : output_tensor->consumers()) { + opcall_dependent[tensor_comsumer].erase(output_tensor->name()); + if (opcall_dependent[tensor_comsumer].empty()) { + opcall_queue.push(tensor_comsumer); + } + } + } + } +} + +std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { + os << "\nAll Tensors:\n"; + for (const auto &kv : pattern_graph.id2owend_tensor()) { + os << " " << kv.first; + } + os << "\n\n"; + + os << "Input Tensors:\n"; + for (const auto &tensor_name : pattern_graph.input_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + os << "Output Tensors:\n"; + for (const auto &tensor_name : pattern_graph.output_tensors()) { + os << " " << tensor_name; + } + os << "\n\n"; + + for (const auto &op_call : pattern_graph.owned_op_call()) { + os << " " << op_call->name() << " : "; + os << "inputs[ "; + for (const auto *input : op_call->inputs()) { + os << input->name() << " "; + } + os << "], "; + + os << "outputs[ "; + for (const auto &output : op_call->outputs()) { + os << output->name() << " "; + } + os << "]\n"; + } + os << "\n"; + return os; +} + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/drr/pattern_graph.h b/paddle/fluid/pir/drr/pattern_graph.h new file mode 100644 index 00000000000000..63bd60eadf17f3 --- /dev/null +++ b/paddle/fluid/pir/drr/pattern_graph.h @@ -0,0 +1,108 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace pir { +namespace drr { + +class Constraint; +class MatchContext; +class OpCall; +class Tensor; + +class PatternGraph { + public: + virtual ~PatternGraph() {} + + const drr::OpCall& AddOpCall(const std::shared_ptr& op_call); + + drr::Tensor& AddTensor(const std::shared_ptr& tensor); + + drr::Tensor& AddTmpTensor(const std::shared_ptr& tensor); + + void UpdateTmpTensor(const std::string& tmp_tensor_name, + const std::string& new_tensor_name); + + const std::unordered_set& input_tensors() const { + return input_tensors_; + } + + const std::unordered_set& output_tensors() const { + return output_tensors_; + } + + size_t CountOfOpCalls() const; + + const std::vector>& owned_op_call() const { + return owned_op_call_; + } + + const std::unordered_map>& + id2owend_tensor() const { + return id2owned_tensor_; + } + + protected: + std::unordered_map> id2owned_tensor_; + std::vector> owned_op_call_; + std::unordered_set input_tensors_; + std::unordered_set output_tensors_; +}; + +std::ostream& operator<<(std::ostream& os, const PatternGraph& pattern_graph); + +class SourcePatternGraph : public PatternGraph { + public: + OpCall* AnchorNode() const; + + std::unordered_set OutputNodes() const; + + private: + friend class DrrPatternContext; +}; + +class ResultPatternGraph : public PatternGraph { + public: + void AssignTensor(const Tensor& from, const Tensor& to); + + const std::unordered_map& tensor_assign_map() + const { + return tensor_assign_map_; + } + + private: + std::unordered_map tensor_assign_map_; +}; + +class GraphTopo { + public: + explicit GraphTopo(const PatternGraph* graph) : graph_(graph) {} + + void WalkGraphNodesTopoOrder( + const std::function& VisitNode) const; + + private: + const PatternGraph* graph_; +}; + +} // namespace drr +} // namespace pir diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt index e1903c903de349..3140d9d20dc09e 100644 --- a/paddle/fluid/pir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -13,6 +13,11 @@ cc_library( SRCS constant_folding_pass.cc DEPS standalone_executor pd_op_to_kernel_pass transform_general_functions) +cc_library( + fused_gemm_epilogue_pass + SRCS fused_gemm_epilogue_pass.cc + DEPS drr) + cc_library( pd_inplace_pass SRCS inplace_pass.cc diff --git a/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc new file mode 100644 index 00000000000000..8585050e8efbf3 --- /dev/null +++ b/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.cc @@ -0,0 +1,295 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h" + +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +namespace { + +class FusedLinearPattern : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op("pd_op.matmul", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op("pd_op.add"); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + } +}; + +class FusedLinearGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul = pat.Op("pd_op.matmul", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &matmul_grad = pat.Op("pd_op.matmul_grad", + {{"transpose_x", pat.Attr("trans_x")}, + {"transpose_y", pat.Attr("trans_y")}}); + const auto &add = pat.Op("pd_op.add"); + const auto &add_grad = pat.Op("pd_op.add_grad"); + + pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w")); + pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); + add_grad({&pat.Tensor("tmp"), &pat.Tensor("bias"), &pat.Tensor("out_grad")}, + {&pat.Tensor("tmp_grad"), &pat.Tensor("bias_grad")}); + matmul_grad({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("tmp_grad")}, + {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return (match_ctx.Tensor("w").Shape().size() == 2 && + match_ctx.Tensor("x").Shape().size() >= 2); + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "none"; + }); + const auto &fused_gemm_epilogue = res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation", act_attr}}}); + const auto &fused_gemm_epilogue_grad = + res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x")}, + {"trans_y", pat.Attr("trans_y")}, + {"activation_grad", act_attr}}}); + fused_gemm_epilogue( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out")}); + fused_gemm_epilogue_grad({&res.Tensor("x"), + &res.Tensor("w"), + &res.NoneTensor(), + &res.Tensor("out_grad")}, + {&res.Tensor("x_grad"), + &res.Tensor("w_grad"), + &res.Tensor("bias_grad")}); + } +}; + +class FusedLinearGeluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = pat.Op("pd_op.gelu")(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("gelu_dx") = pat.Op("pd_op.gelu_grad")(pat.Tensor("fuse_out"), + pat.Tensor("x1_grad")); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act2") == "none"; + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "gelu_grad"; + }); + const auto &fused_gemm_epilogue_grad_new = + res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space2")}); + fused_gemm_epilogue_grad_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space2"), + &res.Tensor("out_grad")}, + {&res.Tensor("gelu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedLinearReluGradPattern + : public pir::drr::DrrPatternBase { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &fused_gemm_epilogue = + pat.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", pat.Attr("act1")}}}); + const auto &fused_gemm_epilogue_grad = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", pat.Attr("act2")}}}); + const auto &fused_gemm_epilogue_grad1 = + pat.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x3")}, + {"trans_y", pat.Attr("trans_y3")}, + {"activation_grad", pat.Attr("act3")}}}); + fused_gemm_epilogue( + {&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("fuse_out"), &pat.Tensor("reserve_space")}); + pat.Tensor("out") = pat.Op("pd_op.relu")(pat.Tensor("fuse_out")); + + fused_gemm_epilogue_grad1({&pat.Tensor("x1"), + &pat.Tensor("w1"), + &pat.Tensor("reserve_space2"), + &pat.Tensor("out_grad")}, + {&pat.Tensor("x1_grad"), + &pat.Tensor("w1_grad"), + &pat.Tensor("bias1_grad")}); + pat.Tensor("relu_dx") = + pat.Op("pd_op.relu_grad")(pat.Tensor("x1"), pat.Tensor("x1_grad")); + fused_gemm_epilogue_grad({&pat.Tensor("x"), + &pat.Tensor("w"), + &pat.Tensor("reserve_space1"), + &pat.Tensor("relu_dx")}, + {&pat.Tensor("x_grad"), + &pat.Tensor("w_grad"), + &pat.Tensor("bias_grad")}); + + pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) { + return match_ctx.Attr("act1") == "none" && + match_ctx.Attr("act3") == "none"; + }); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &act_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu"; + }); + const auto &fused_gemm_epilogue_new = + res.Op("pd_op.fused_gemm_epilogue", + {{{"trans_x", pat.Attr("trans_x1")}, + {"trans_y", pat.Attr("trans_y1")}, + {"activation", act_attr}}}); + const auto &act_grad_attr = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any { + return "relu_grad"; + }); + const auto &fused_gemm_epilogue_grad1_new = + res.Op("pd_op.fused_gemm_epilogue_grad", + {{{"trans_x", pat.Attr("trans_x2")}, + {"trans_y", pat.Attr("trans_y2")}, + {"activation_grad", act_grad_attr}}}); + fused_gemm_epilogue_new( + {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, + {&res.Tensor("out"), &res.Tensor("reserve_space3")}); + fused_gemm_epilogue_grad1_new({&res.Tensor("x1"), + &res.Tensor("w1"), + &res.Tensor("reserve_space3"), + &res.Tensor("out_grad")}, + {&res.Tensor("relu_dx"), + &res.Tensor("w1_grad"), + &res.Tensor("bias1_grad")}); + } +}; + +class FusedGemmEpiloguePass : public pir::Pass { + public: + FusedGemmEpiloguePass() : pir::Pass("FusedGemmEpiloguePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(FusedLinearGradPattern().Build(context)); + ps.Add(FusedLinearPattern().Build(context)); + ps.Add(FusedLinearGeluGradPattern().Build(context)); + ps.Add(FusedLinearReluGradPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateFusedGemmEpiloguePass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(fused_gemm_epilogue, FusedGemmEpiloguePass); diff --git a/paddle/cinn/hlir/framework/convert_to_dialect.h b/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h similarity index 73% rename from paddle/cinn/hlir/framework/convert_to_dialect.h rename to paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h index 7ea0a2ace40c7a..61f503a530f729 100644 --- a/paddle/cinn/hlir/framework/convert_to_dialect.h +++ b/paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h @@ -15,19 +15,12 @@ #pragma once #include +#include "paddle/pir/core/dll_decl.h" namespace pir { -class Program; -} // namespace pir -namespace cinn { -namespace hlir { -namespace framework { -class Program; +class Pass; -std::unique_ptr<::pir::Program> ConvertToRuntimeDialect( - const hlir::framework::Program& program); +IR_API std::unique_ptr CreateFusedGemmEpiloguePass(); -} // namespace framework -} // namespace hlir -} // namespace cinn +} // namespace pir diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index f1566725f9326e..f70fc125689911 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -111,7 +111,7 @@ static std::unordered_set GetSkipDeletionValues(pir::Block* block) { continue; } if (upper_op_name == "pd_op.fetch" || - upper_op_name == "pd_op.shadow_output") { + upper_op_name == "builtin.shadow_output") { skip_dels.insert(op->operand_source(0)); continue; } @@ -349,7 +349,7 @@ class InplacePass : public pir::Pass { pir::BoolAttribute::get(pir::IrContext::Instance(), true)); } LOG_FIRST_N(INFO, 1) - << "Apply inplace pass on lowering ::ir::Program to Kernel Dialect."; + << "Apply inplace pass on lowering ::pir::Program to Kernel Dialect."; } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index eae6f20a34eaa7..3ac3db56cfd41d 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" + #include #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" @@ -28,7 +30,6 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" @@ -36,7 +37,10 @@ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/dialect/control_flow/ir/cf_ops.h" +#include "paddle/utils/flags.h" +PHI_DECLARE_bool(print_ir); namespace paddle { namespace dialect { @@ -61,13 +65,17 @@ const std::unordered_set UnchangeOutputOps = { "pd_op.fetch", "builtin.set_parameter", "builtin.get_parameter", - "pd_op.shadow_output"}; - -const std::unordered_set SpecialLowerOps = {"builtin.combine", - "builtin.slice", - "builtin.split", - "pd_op.if", - "cf.yield"}; + "builtin.shadow_output", + "cinn_runtime.jit_kernel"}; +const std::unordered_set SpecialLowerOps = { + "builtin.combine", + "builtin.slice", + "builtin.split", + "pd_op.if", + "pd_op.while", + "cf.yield", + "cf.cond_yield", + "cinn_runtime.jit_kernel"}; bool NeedFallBackCpu(const pir::Operation* op, const std::string& kernel_fn_name, @@ -106,7 +114,8 @@ phi::Backend GetDstBackend(const std::string& op_name, const OpYamlInfoParser* op_yaml_info_parser, phi::Backend kernel_def_backend, size_t input_index) { - if (op_name == "builtin.set_parameter" && + if ((op_name == "builtin.set_parameter" || + op_name == "builtin.shadow_output") && place.GetType() == phi::AllocationType::GPU) { // NOTE: align old executor, all the paramter are initilizered // on backend of executor place defined @@ -223,7 +232,7 @@ std::vector> GetFakeTensorList( return vec_res; } -pir::OpResult AddPlaceTransferOp(pir::OpResult in, +pir::OpResult AddPlaceTransferOp(pir::Value in, pir::Type out_type, const phi::Place& src_place, const phi::Place& dst_place, @@ -247,9 +256,9 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in, pir::Operation* op = pir::Operation::Create({in}, op_attribute, {out_type}, kernel_op_info); - if (in.owner()->HasAttribute(kAttrIsPersisable)) { - op->set_attribute(kAttrIsPersisable, - in.owner()->attribute(kAttrIsPersisable)); + auto in_op = in.dyn_cast().owner(); + if (in_op && in_op->HasAttribute(kAttrIsPersisable)) { + op->set_attribute(kAttrIsPersisable, in_op->attribute(kAttrIsPersisable)); } block->push_back(op); @@ -325,7 +334,7 @@ pir::Type BuildOutputType(pir::Type type, phi::DataType GetKernelDataTypeByYamlInfo( const pir::Operation* op, - const std::unordered_map& map_value_pair, + const std::unordered_map& map_value_pair, const dialect::OpYamlInfoParser* op_info_parser) { auto& attr_map = op->attributes(); auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; @@ -405,7 +414,7 @@ phi::DataType GetKernelDataTypeByYamlInfo( phi::Backend GetKernelBackendByYamlInfo( const pir::Operation* op, - const std::unordered_map& map_value_pair, + const std::unordered_map& map_value_pair, const dialect::OpYamlInfoParser* op_info_parser, const phi::Place& place) { auto& attr_map = op->attributes(); @@ -482,11 +491,12 @@ phi::KernelKey GetKernelKey( pir::Operation* op, const phi::Place& place, const std::string& kernel_fn_str, - const std::unordered_map& map_value_pair, + const std::unordered_map& map_value_pair, dialect::OpYamlInfoParser* op_info_parser = nullptr) { if (op->isa()) { // NOTE, for now feed op don't need a kernel, so the data type from Op // Result the next op use base program datatype + VLOG(6) << "FeedOp doesn't need a kernel. Backend: CPU, DataLayout: ANY"; return {phi::Backend::CPU, phi::DataLayout::ANY, TransToPhiDataType( @@ -496,6 +506,7 @@ phi::KernelKey GetKernelKey( if (op->isa()) { // NOTE, for now feed op don't need a kernel, so the data type from Op // Result the next op use base program datatype + VLOG(6) << "DataOp doesn't need a kernel"; auto data_place = op->attributes().at("place").dyn_cast().data(); @@ -507,7 +518,8 @@ phi::KernelKey GetKernelKey( op->result(0).type().dyn_cast().dtype())}; } - if (op->name() == "pd_op.seed") { + if (op->isa()) { + VLOG(6) << "SeedOp doesn't need a kernel"; auto backend = paddle::experimental::ParseBackend(place); return {backend, phi::DataLayout::ANY, @@ -516,6 +528,7 @@ phi::KernelKey GetKernelKey( } if (op->isa()) { + VLOG(6) << "FullWithTensorOp doesn't need a kernel"; auto backend = paddle::experimental::ParseBackend(place); auto dtype = op->attributes() .at("dtype") @@ -533,31 +546,24 @@ phi::KernelKey GetKernelKey( // only suppurt non vector input for now int tensor_input_number = static_cast(op_info_parser->InputTensorNumber()); - + VLOG(8) << "Begin to infer kernel key from op_info_parser(defined by yaml " + "info)"; // get datatype info kernel_data_type = GetKernelDataTypeByYamlInfo(op, map_value_pair, op_info_parser); + VLOG(8) << "Infer kernel data_type: [" << kernel_data_type + << "] from yaml info"; kernel_backend = GetKernelBackendByYamlInfo(op, map_value_pair, op_info_parser, place); - + VLOG(8) << "Infer kernel backend: [" << kernel_backend + << "] from yaml info"; // parse all the input tensor if (tensor_input_number == 0 || op->isa()) { // all the information have to get from attribute and context - - if (op->isa()) { - // try to process uniform, use shape to determin backend - // TODO(phlrain): shuold support other initilize op - auto define_op = - op->operand_source(0).dyn_cast().owner(); - if (define_op->isa()) { - auto shape = define_op->attribute("value") - .data() - .GetData(); - } - } - if (kernel_backend == phi::Backend::UNDEFINED) { kernel_backend = paddle::experimental::ParseBackend(place); + VLOG(8) << "Infer kernel backend: [" << kernel_backend + << "] when tensor_input_number == 0 or is Full_Op"; } } } @@ -566,15 +572,17 @@ phi::KernelKey GetKernelKey( kernel_data_type == phi::DataType::UNDEFINED) && op->num_operands() > 0) { paddle::experimental::detail::KernelKeyParser kernel_key_parser; - + VLOG(8) << "Begin to infer kernel key from op operands"; for (size_t i = 0; i < op->num_operands(); ++i) { // NOTE, only op with OpYamlInfo can have TensorArr if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) { + VLOG(8) << "input (" << i << ") doesn't have TensorArr"; continue; } auto input_tmp = op->operand_source(i); // NOTE: if not input_tmp, it's an optional input if (!input_tmp) { + VLOG(8) << "input (" << i << ") is NULL (optional input)"; continue; } auto new_input_tmp = map_value_pair.at(input_tmp); @@ -589,10 +597,12 @@ phi::KernelKey GetKernelKey( // don't know how to select the kernel in the next of op that // uses data op outout as inputs. So, we need set kernel backend // manually. - if (op->operand_source(i) - .dyn_cast() - .owner() - ->isa()) { + auto op_res = op->operand_source(i).dyn_cast(); + + if (!op_res) { + continue; + } + if (op_res.owner()->isa()) { auto data_op = op->operand_source(i).dyn_cast().owner(); auto data_place = data_op->attribute("place").data(); @@ -604,6 +614,8 @@ phi::KernelKey GetKernelKey( kernel_key_parser.key_set.backend_set = kernel_key_parser.key_set.backend_set | paddle::experimental::BackendSet(data_op_backend); + VLOG(8) << "Update kernel backend set from owner op (DataOp): " + << data_op_backend; } else if (op->operand_source(i) .dyn_cast() .owner() @@ -628,6 +640,8 @@ phi::KernelKey GetKernelKey( kernel_key_parser.key_set.backend_set = kernel_key_parser.key_set.backend_set | paddle::experimental::BackendSet(data_op_backend); + VLOG(8) << "Update kernel backend set from owner op (CombineOp): " + << data_op_backend; break; } } @@ -640,16 +654,26 @@ phi::KernelKey GetKernelKey( if (kernel_backend == phi::Backend::UNDEFINED) { kernel_backend = kernel_key.backend(); + if (kernel_backend != phi::Backend::UNDEFINED) { + VLOG(8) << "Infer kernel backend from op operands"; + } } if (kernel_layout == phi::DataLayout::UNDEFINED) { kernel_layout = kernel_key.layout(); + if (kernel_layout != phi::DataLayout::UNDEFINED) { + VLOG(8) << "Infer kernel layout from op operands"; + } } if (kernel_data_type == phi::DataType::UNDEFINED) { kernel_data_type = kernel_key.dtype(); + if (kernel_data_type != phi::DataType::UNDEFINED) { + VLOG(8) << "Infer kernel data_type from op operands"; + } } } if (kernel_backend == phi::Backend::UNDEFINED) { + VLOG(8) << "Kernel backend cannot be infered from op operands"; kernel_backend = paddle::experimental::ParseBackend(place); } @@ -657,13 +681,17 @@ phi::KernelKey GetKernelKey( if (op->isa()) { res.set_dtype(phi::DataType::FLOAT32); + VLOG(8) << "LoadCombineOp's kernel data type must be FLOAT32"; } if (NeedFallBackCpu((op), kernel_fn_str, res)) { res.set_backend(phi::Backend::CPU); + VLOG(8) << "kernel backend must be on CPU when need fallback"; } if (NeedFallBackFromGPUDNN2GPU(op, res)) { res.set_backend(phi::Backend::GPU); + VLOG(8) << "kernel backend must be on GPU when need fallback from GPUDNN " + "to GPU"; } return res; @@ -675,8 +703,9 @@ void HandleForIfOp( pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair) { auto old_cond = op_item->operand_source(0); + PADDLE_ENFORCE_EQ( map_value_pair->count(old_cond), true, @@ -745,11 +774,54 @@ void HandleForIfOp( } } -pir::OpResult GetNewInput( +void HandleForWhileOp( + const phi::Place& place, + pir::Operation* op_item, + pir::Block* block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + std::vector vec_in; + pir::Value cond_val; + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + + PADDLE_ENFORCE_EQ( + map_value_pair->count(cur_in), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] op MUST in map pair", 0, op_item->name())); + auto new_in = map_value_pair->at(cur_in); + if (i == 0) + cond_val = new_in; + else + vec_in.push_back(new_in); + } + + pir::Builder builder(ctx, block); + + auto base_while_op = op_item->dyn_cast(); + auto new_while_op = builder.Build(cond_val, vec_in); + pir::Block* body_block = new_while_op.body_block(); + for (size_t i = 0; i < vec_in.size(); ++i) { + auto block_arg = body_block->AddArgument(vec_in[i].type()); + (*map_value_pair)[base_while_op.body_block()->argument(i)] = block_arg; + } + + // process body block + ProcessBlock(place, + base_while_op.body_block(), + body_block, + ctx, + map_op_pair, + map_value_pair); +} + +pir::Value GetNewInput( const pir::Value cur_in, - const std::unordered_map& map_value_pair, + const std::unordered_map& map_value_pair, const int index, - const std::string op_name) { + const std::string& op_name) { PADDLE_ENFORCE_EQ( map_value_pair.count(cur_in), true, @@ -765,11 +837,16 @@ void HandleForSpecialOp( pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair) { if (op_item->isa()) { HandleForIfOp(place, op_item, block, ctx, map_op_pair, map_value_pair); return; } + + if (op_item->isa()) { + HandleForWhileOp(place, op_item, block, ctx, map_op_pair, map_value_pair); + return; + } std::vector vec_inputs; std::vector op_output_types; if (op_item->isa<::pir::CombineOp>()) { @@ -782,7 +859,8 @@ void HandleForSpecialOp( vec_inputs.emplace_back(); continue; } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); vec_inputs.push_back(new_in); vec_inner_types.push_back(new_in.type()); } @@ -801,7 +879,8 @@ void HandleForSpecialOp( vec_inputs.emplace_back(); continue; } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); vec_inputs.push_back(new_in); if (new_in.type().isa()) { @@ -826,7 +905,8 @@ void HandleForSpecialOp( vec_inputs.emplace_back(); continue; } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); vec_inputs.push_back(new_in); if (new_in.type().isa()) { @@ -842,7 +922,7 @@ void HandleForSpecialOp( } } - if (op_item->name() == "cf.yield") { + if (op_item->isa<::pir::YieldOp>()) { if (op_item->num_operands() > 0) { for (size_t i = 0; i < op_item->num_operands(); ++i) { auto cur_in = op_item->operand_source(i); @@ -850,12 +930,35 @@ void HandleForSpecialOp( vec_inputs.emplace_back(); continue; } - auto new_in = GetNewInput(cur_in, *map_value_pair, i, op_item->name()); + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); vec_inputs.push_back(new_in); } } } + if (op_item->name() == "cinn_runtime.jit_kernel") { + if (op_item->num_operands() > 0) { + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + vec_inputs.push_back(new_in); + } + } + + for (size_t i = 0; i < op_item->num_results(); ++i) { + op_output_types.push_back(paddle::dialect::AllocatedDenseTensorType::get( + ctx, + place, + op_item->result(i).type().dyn_cast())); + } + } + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); // Generate new op pir::Operation* op = pir::Operation::Create( @@ -960,7 +1063,7 @@ std::vector BuildOpInputList( const OpYamlInfoParser* op_info_parser, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair, + std::unordered_map* map_value_pair, pir::Block* block) { if (op_item->num_operands() == 0) { return {}; @@ -979,6 +1082,7 @@ std::vector BuildOpInputList( true, phi::errors::PreconditionNotMet( "[%d]'s input of [%s] op MUST in map pair", i, op_item->name())); + auto new_in = map_value_pair->at(cur_in); auto new_in_type = new_in.type(); @@ -986,6 +1090,17 @@ std::vector BuildOpInputList( auto& kernel = phi::KernelFactory::Instance().SelectKernelWithGPUDNN( kernel_fn_str, kernel_key); + int tensor_param_index = i; + if (kernel.IsValid()) { + tensor_param_index = op_info_parser->GetTensorParamIndexByArgsName( + op_info_parser->InputNames()[i]); + // the input of op args is not the kernel parameter + if (tensor_param_index == -1) { + vec_inputs.emplace_back(new_in); + continue; + } + } + bool check_place_transfer = (op_item->isa<::pir::SetParameterOp>()) || (kernel.IsValid() && (!UnchangeOutputOps.count(op_item->name()))); @@ -1000,11 +1115,13 @@ std::vector BuildOpInputList( auto args_def = kernel.args_def(); auto input_defs = args_def.input_defs(); - auto dst_backend = GetDstBackend(op_item->name(), - place, - op_info_parser, - kernel.InputAt(i).backend, - i); + auto dst_backend = + GetDstBackend(op_item->name(), + place, + op_info_parser, + kernel.InputAt(tensor_param_index).backend, + i); + VLOG(6) << "Infer kernel backend from input " << i << " of op "; bool need_trans = (in_place.GetType() != phi::AllocationType::UNDEFINED) && @@ -1063,12 +1180,13 @@ std::vector BuildOpInputList( (op_info_parser != nullptr && !op_info_parser->IsTensorAttribute(i)) && (paddle::experimental::NeedTransformPlace( - place, kernel.InputAt(i).backend, {})); + place, kernel.InputAt(tensor_param_index).backend, {})); if (need_trans) { VLOG(6) << "need trans from " << place << " to " << kernel_key.backend(); // build memcopy op - auto out_place = phi::TransToPhiPlace(kernel.InputAt(i).backend); + auto out_place = phi::TransToPhiPlace( + kernel.InputAt(tensor_param_index).backend); pir::Type out_type; if (in_i_type.isa()) { out_type = dialect::AllocatedDenseTensorType::get( @@ -1121,12 +1239,13 @@ std::vector BuildOpInputList( auto args_def = kernel.args_def(); auto input_defs = args_def.input_defs(); - auto dst_backend = GetDstBackend(op_item->name(), - place, - op_info_parser, - kernel.InputAt(i).backend, - i); - + auto dst_backend = + GetDstBackend(op_item->name(), + place, + op_info_parser, + kernel.InputAt(tensor_param_index).backend, + i); + VLOG(6) << "Infer kernel backend from input " << i << " of op "; bool need_trans = (in_place.GetType() != phi::AllocationType::UNDEFINED) && (paddle::experimental::NeedTransformPlace( @@ -1150,8 +1269,9 @@ std::vector BuildOpInputList( new_in, out_type, in_place, out_place, kernel_key, block); } } else { - PADDLE_THROW(phi::errors::Unimplemented( - "only support allocated dense tensor type for now")); + PADDLE_THROW( + phi::errors::Unimplemented("only support allocated dense tensor " + "type and selected rows for now")); } } vec_inputs.push_back(new_in); @@ -1167,7 +1287,7 @@ void AddShadowFeed( pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair) { bool feed_op_add_shadow_feed = (op_item->isa()) && platform::is_gpu_place(place); bool data_op_add_shadow_feed = @@ -1249,7 +1369,7 @@ pir::Operation* BuildPhiKernelOp( pir::Block* block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair) { std::unordered_map op_attribute{ {"op_name", pir::StrAttribute::get(ctx, op_item->name())}, {"kernel_name", pir::StrAttribute::get(ctx, kernel_fn_str)}, @@ -1269,7 +1389,7 @@ pir::Operation* BuildPhiKernelOp( pir::OpInfo legacy_kernel_op_info = ctx->GetRegisteredOpInfo(paddle::dialect::LegacyKernelOp::name()); - pir::Operation* op; + pir::Operation* op = nullptr; if (dialect::IsLegacyOp(op_item->name())) { op = pir::Operation::Create( vec_inputs, op_attribute, op_output_types, legacy_kernel_op_info); @@ -1297,18 +1417,21 @@ void ProcessBlock( pir::Block* new_block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair) { + std::unordered_map* map_value_pair) { auto skip_feed_names = GetSkipFeedNames(block); for (auto op_item : *block) { VLOG(6) << "op name " << op_item->name(); if ((op_item->isa()) && SkipFeedOp(op_item, skip_feed_names)) { + VLOG(6) << "Skip FeedOp while lowering to kernel pass"; continue; } // HandleSpecialOp if (SpecialLowerOps.count(op_item->name())) { + VLOG(6) << "Handle Special Op: [" << op_item->name() + << "] while lowering to kernel pass"; HandleForSpecialOp( place, op_item, new_block, ctx, map_op_pair, map_value_pair); continue; @@ -1357,6 +1480,10 @@ void ProcessBlock( std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, phi::Place place) { + if (FLAGS_print_ir) { + std::cout << "IR before lowering = " << *prog << std::endl; + } + auto program = std::make_unique(pir::IrContext::Instance()); auto block = prog->block(); @@ -1366,16 +1493,15 @@ std::unique_ptr PdOpLowerToKernelPass(pir::Program* prog, ctx->GetOrRegisterDialect(); std::unordered_map map_op_pair; - std::unordered_map map_value_pair; + std::unordered_map map_value_pair; ProcessBlock( place, block, program->block(), ctx, &map_op_pair, &map_value_pair); - if (VLOG_IS_ON(2)) { - std::stringstream ss1; - program->Print(ss1); - VLOG(2) << "Program after lowering to kernel pass : " << ss1.str(); + if (FLAGS_print_ir) { + std::cout << "IR after lowering = " << *program << std::endl; } + return program; } } // namespace dialect diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h index 35b5484508a6f2..c1f0fe0cb85d94 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h @@ -28,6 +28,6 @@ void ProcessBlock( pir::Block* new_block, pir::IrContext* ctx, std::unordered_map* map_op_pair, - std::unordered_map* map_value_pair); + std::unordered_map* map_value_pair); } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc index abe5bcd8c6c852..d585f2f4c64f9f 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -101,9 +101,15 @@ std::vector get_xpu_kp_op_support_type( std::vector get_xpu_op_support_type( const std::string& op_name, phi::backends::xpu::XPUVersion version) { - auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 - ? phi::backends::xpu::get_kl1_ops() - : phi::backends::xpu::get_kl2_ops(); + phi::backends::xpu::XPUOpMap ops; + if (version == phi::backends::xpu::XPUVersion::XPU1) { + ops = phi::backends::xpu::get_kl1_ops(); + } else if (version == phi::backends::xpu::XPUVersion::XPU2) { + ops = phi::backends::xpu::get_kl2_ops(); + } else { + ops = phi::backends::xpu::get_kl3_ops(); + } + std::vector res; if (ops.find(op_name) != ops.end()) { auto& dtypes = ops[op_name]; @@ -115,9 +121,15 @@ std::vector get_xpu_op_support_type( } XPUOpListMap get_xpu_op_list(phi::backends::xpu::XPUVersion version) { - auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 - ? phi::backends::xpu::get_kl1_ops() - : phi::backends::xpu::get_kl2_ops(); + phi::backends::xpu::XPUOpMap ops; + if (version == phi::backends::xpu::XPUVersion::XPU1) { + ops = phi::backends::xpu::get_kl1_ops(); + } else if (version == phi::backends::xpu::XPUVersion::XPU2) { + ops = phi::backends::xpu::get_kl2_ops(); + } else { + ops = phi::backends::xpu::get_kl3_ops(); + } + XPUOpListMap res; for (auto& op : ops) { std::vector op_types; diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index 3eb7f2d9f22721..a77e396adee5f4 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -419,7 +419,7 @@ void SendBroadCastCommID(std::vector servers, // connect with server std::vector connects; - for (auto server : servers) { + for (auto const& server : servers) { VLOG(3) << "connecting endpoint: " << server; int conn = ConnectAddr(server, head); connects.push_back(conn); diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 67512474567d30..44c17c32fa8d56 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -139,8 +139,8 @@ RecordMemEvent::RecordMemEvent(const void *ptr, } if (type == TracerMemEventType::Allocate) { - uint64_t current_allocated; - uint64_t peak_allocated; + uint64_t current_allocated = 0; + uint64_t peak_allocated = 0; uint64_t current_reserved = 0; // 0 means keep the same as before uint64_t peak_reserved = 0; // 0 means keep the same as before if (platform::is_cpu_place(place) || @@ -223,8 +223,8 @@ RecordMemEvent::RecordMemEvent(const void *ptr, peak_allocated, peak_reserved); } else if (type == TracerMemEventType::ReservedAllocate) { - uint64_t current_reserved; - uint64_t peak_reserved; + uint64_t current_reserved = 0; + uint64_t peak_reserved = 0; uint64_t current_allocated = 0; // 0 means keep the same as before uint64_t peak_allocated = 0; // 0 means keep the same as before if (platform::is_cpu_place(place) || @@ -306,8 +306,8 @@ RecordMemEvent::RecordMemEvent(const void *ptr, peak_allocated, peak_reserved); } else if (type == TracerMemEventType::Free) { - uint64_t current_allocated; - uint64_t peak_allocated; + uint64_t current_allocated = 0; + uint64_t peak_allocated = 0; uint64_t current_reserved = 0; // 0 means keep the same as before uint64_t peak_reserved = 0; // 0 means keep the same as before if (platform::is_cpu_place(place) || @@ -389,8 +389,8 @@ RecordMemEvent::RecordMemEvent(const void *ptr, peak_allocated, peak_reserved); } else if (type == TracerMemEventType::ReservedFree) { - uint64_t current_reserved; - uint64_t peak_reserved; + uint64_t current_reserved = 0; + uint64_t peak_reserved = 0; uint64_t current_allocated = 0; // 0 means keep the same as before uint64_t peak_allocated = 0; // 0 means keep the same as before if (platform::is_cpu_place(place) || diff --git a/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc b/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc index 7ea473dfdc1505..795aab1e128fd4 100644 --- a/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc +++ b/paddle/fluid/platform/profiler/custom_device/custom_tracer.cc @@ -28,6 +28,10 @@ namespace platform { CustomTracer::CustomTracer(const std::string& dev_type) : dev_type_(dev_type) { #ifdef PADDLE_WITH_CUSTOM_DEVICE + auto selected_devices = phi::DeviceManager::GetSelectedDeviceList(dev_type_); + if (selected_devices.size()) { + phi::DeviceManager::SetDevice(dev_type_, selected_devices[0]); + } phi::DeviceManager::ProfilerInitialize(dev_type_, &collector_, &context_); #endif } @@ -105,7 +109,7 @@ void CustomTracer::CollectTraceData(TraceEventCollector* collector) { for (auto de : collector_.DeviceEvents()) { collector->AddDeviceEvent(std::move(de)); } - for (auto tn : collector_.ThreadNames()) { + for (auto const& tn : collector_.ThreadNames()) { collector->AddThreadName(tn.first, tn.second); } collector_.ClearAll(); diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index ec3bd5741371eb..5a1a6e335abeb5 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -48,3 +48,4 @@ - reshape - erf - tanh +- sign diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 53369e956d7b82..64c431b3d237fe 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -585,9 +585,8 @@ void sigmoid_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { template void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { - auto abs_tmp = abs(x); - auto divide_tmp = divide(x, abs_tmp); - set_output(out_grad * divide_tmp, x_grad); + auto sign_tmp = sign(x); + set_output(out_grad * sign_tmp, x_grad); } } diff --git a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc index d76a8ad5523bb9..f89a898ca1a58e 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc @@ -27,12 +27,12 @@ namespace paddle { namespace prim { using Tensor = paddle::Tensor; template <> -Tensor empty(const paddle::experimental::IntArray& shape, - phi::DataType dtype, - const paddle::Place& place) { +TEST_API Tensor empty(const paddle::experimental::IntArray& shape, + phi::DataType dtype, + const paddle::Place& place) { framework::VarDesc* new_var = StaticCompositeContext::Instance().GetBlock()->Var( - std::move(StaticCompositeContext::Instance().GenerateUniqueName())); + StaticCompositeContext::Instance().GenerateUniqueName()); new_var->SetShape(shape.GetData()); new_var->SetDataType(framework::TransToProtoVarType(dtype)); // Place is not supported in static mode diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index c08405bb18dbed..b88292d488ab69 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -25,7 +25,6 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/type_defs.h" - namespace paddle { namespace prim { @@ -109,7 +108,7 @@ class StaticCompositeContext { static thread_local bool enable_bwd_prim_; static thread_local bool enable_fwd_prim_; static thread_local bool enable_eager_prim_; - static StaticCompositeContext* static_composite_context_; + TEST_API static StaticCompositeContext* static_composite_context_; DISABLE_COPY_AND_ASSIGN(StaticCompositeContext); }; diff --git a/paddle/fluid/primitive/backend/manual/manual_backend.h b/paddle/fluid/primitive/backend/manual/manual_backend.h index 3c9340164ac012..4faabab79f6852 100644 --- a/paddle/fluid/primitive/backend/manual/manual_backend.h +++ b/paddle/fluid/primitive/backend/manual/manual_backend.h @@ -24,7 +24,7 @@ namespace primitive { namespace backend { using Tensor = paddle::Tensor; -using Scalar = paddle::experimental::Scalar; +using Scalar = phi::Scalar; using IntArray = paddle::experimental::IntArray; using DataType = phi::DataType; @@ -32,6 +32,13 @@ template std::vector add_n_grad(const std::vector& x, const Tensor& out_grad); +template +Tensor embedding_grad(const Tensor& x, + const Tensor& weight, + const Tensor& out_grad, + int64_t padding_idx = -1, + bool sparse = false); + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc index 7b33200336d000..b115e6a0210974 100644 --- a/paddle/fluid/primitive/backend/manual/manual_static_backend.cc +++ b/paddle/fluid/primitive/backend/manual/manual_static_backend.cc @@ -45,6 +45,23 @@ std::vector add_n_grad(const std::vector& x, return x_grad; } +template <> +Tensor embedding_grad(const Tensor& x, + const Tensor& weight, + const Tensor& out_grad, + int64_t padding_idx, + bool sparse) { + pir::Value x_res = std::static_pointer_cast(x.impl())->value(); + pir::Value weight_res = + std::static_pointer_cast(weight.impl())->value(); + pir::Value out_grad_res = + std::static_pointer_cast(out_grad.impl())->value(); + auto op_res = paddle::dialect::embedding_grad( + x_res, weight_res, out_grad_res, padding_idx, sparse); + Tensor out(std::make_shared(op_res)); + return out; +} + } // namespace backend } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index da9e12fa817c59..f9f0d5c32b11c9 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -37,104 +37,13 @@ # fmt: on -VJPS = [ - 'where_grad', - 'tril_grad', - 'triu_grad', - 'tile_grad', - 'tanh_grad', - 'mean_grad', - 'add_grad', - 'divide_grad', - 'sum_grad', - 'concat_grad', - 'split_grad', - 'split_with_num_grad', - 'gelu_grad', - 'softmax_grad', - 'silu_grad', - 'multiply_grad', - 'subtract_grad', - 'erf_grad', - 'expand_grad', - 'exp_grad', - 'expm1_grad', - 'elementwise_pow_grad', - 'fused_softmax_mask_upper_triangle_grad', - 'matmul_grad', - 'pow_grad', - 'rsqrt_grad', - 'slice_grad', - 'transpose_grad', - 'square_grad', - 'dropout_grad', - 'cast_grad', - 'slice_double_grad', - 'layer_norm_grad', - 'embedding_grad', - 'scale_grad', - 'gather_nd_grad', - 'stack_grad', - 'squeeze_grad', - 'unsqueeze_grad', - 'poisson_grad', - 'gumbel_softmax_grad', - 'conv2d_grad', - 'depthwise_conv2d_grad', - 'sqrt_grad', - 'flatten_grad', - 'relu_grad', - 'abs_grad', - 'log_grad', - 'clip_grad', - 'ceil_grad', - 'frobenius_norm_grad', - 'p_norm_grad', - 'maximum_grad', - 'argsort_grad', - 'min_grad', - 'batch_norm_grad', - 'max_pool2d_with_index_grad', - 'pool2d_grad', - 'minimum_grad', - 'prod_grad', - 'round_grad', - 'sin_grad', - 'cos_grad', - 'dot_grad', - 'floor_grad', - 'topk_grad', - 'square_grad', - 'gather_grad', - 'label_smooth_grad', - 'cross_entropy_with_softmax_grad', - 'mean_all_grad', - 'cumsum_grad', - 'linear_interp_grad', - 'bilinear_interp_grad', - 'trilinear_interp_grad', - 'nearest_interp_grad', - 'bicubic_interp_grad', - 'assign_grad', - 'assign_out__grad', - 'real_grad', - 'flip_grad', - 'softmax_grad', - 'expand_grad', - 'conv2d_transpose_grad', - 'depthwise_conv2d_transpose_grad', - 'sigmoid_grad', - 'pad_grad', - 'pad3d_grad', - 'einsum_grad', - 'leaky_relu_grad', - 'log10_grad', - 'conv3d_grad', - 'solve_grad', - 'diag_grad', - 'trace_grad', +VJPS_BLACK_LIST = [ + 'reshape_grad', + 'add_n_grad', ] +BACKENDS_BLACK_LIST = ['copy_to', 'add_n_grad', "allclose", "isclose"] + PRIM_VJP = [ 'divide_grad', @@ -148,153 +57,25 @@ 'tanh_grad', 'transpose_grad', 'concat_grad', + 'erf_grad', + 'exp_grad', + 'expand_grad', + 'log_grad', + 'gather_nd_grad', + 'pad_grad', + 'max_grad', + 'slice_grad', + 'tile_grad', ] # vjp list of primitive op CUSTOM_VJP = [ 'gelu_grad', 'layer_norm_grad', 'dropout_grad', -] # custom vjp list of composite op -VJP_COMPS = PRIM_VJP + CUSTOM_VJP - -BACKENDS = [ - 'where_grad', - 'tril_grad', - 'triu_grad', - 'tile_grad', - 'add_n', - 'mean', - 'sum', - 'divide', - 'full', - 'tanh', - 'tanh_grad', - 'mean_grad', - 'concat', - 'add', - 'multiply', - 'elementwise_pow', - 'scale', - 'reshape', - 'expand', - 'tile', - 'add_grad', - 'divide_grad', - 'sum_grad', - 'concat_grad', - 'split_grad', - 'split_with_num_grad', - 'gelu_grad', - 'softmax_grad', 'silu_grad', - 'multiply_grad', - 'subtract_grad', - 'erf_grad', - 'expand_grad', - 'exp_grad', - 'expm1_grad', - 'multiply', - 'exp', - 'erf', - 'cast', - 'elementwise_pow_grad', - 'fused_softmax_mask_upper_triangle_grad', - 'matmul_grad', - 'pow_grad', - 'reshape_grad', - 'rsqrt_grad', - 'slice_grad', - 'transpose_grad', - 'subtract', - 'assign', - 'equal', - 'greater_equal', - 'greater_than', - 'less_equal', - 'less_than', - 'matmul', - 'max', - 'maximum', - 'minimum', - 'not_equal', - 'abs', - 'bitwise_and', - 'bitwise_not', - 'bitwise_or', - 'bitwise_xor', - 'floor', - 'gather_nd', - 'log', - 'roll', - 'scatter', - 'scatter_nd_add', - 'square_grad', - 'dropout_grad', - 'slice', - 'layer_norm_grad', - 'embedding_grad', - 'sqrt', - 'uniform', - 'poisson_grad', - 'gumbel_softmax_grad', - 'split', - 'transpose', - 'gather_nd_grad', - 'stack_grad', - 'squeeze_grad', - 'unsqueeze_grad', - 'conv2d_grad', - 'depthwise_conv2d_grad', - 'sqrt_grad', - 'flatten_grad', - 'relu_grad', - 'abs_grad', - 'log_grad', - 'clip_grad', - 'ceil_grad', - 'frobenius_norm_grad', - 'p_norm_grad', - 'maximum_grad', - 'argsort_grad', - 'min_grad', - 'batch_norm_grad', - 'max_pool2d_with_index_grad', - 'pool2d_grad', - 'minimum_grad', - 'prod_grad', - 'round_grad', - 'sin_grad', - 'cos_grad', - 'dot_grad', - 'floor_grad', - 'topk_grad', - 'square_grad', - 'gather_grad', - 'label_smooth_grad', - 'cross_entropy_with_softmax_grad', - 'mean_all_grad', - 'cumsum_grad', - 'linear_interp_grad', - 'bilinear_interp_grad', - 'trilinear_interp_grad', - 'nearest_interp_grad', - 'bicubic_interp_grad', - 'assign_out__grad', - 'real_grad', 'softmax_grad', - 'conv2d_transpose_grad', - 'depthwise_conv2d_transpose_grad', - 'sigmoid_grad', - 'pad_grad', - 'pad3d_grad', - 'einsum_grad', - 'leaky_relu_grad', - 'log10_grad', - 'conv3d_grad', - 'solve_grad', - 'diag_grad', - 'trace_grad', - 'flip', -] + 'sqrt_grad', +] # custom vjp list of composite op +VJP_COMPS = PRIM_VJP + CUSTOM_VJP def load(path: pathlib.Path): @@ -346,6 +127,7 @@ def render(src_dir: pathlib.Path, dst_dir: pathlib.Path, *args, **kwargs): 'datatype': op_gen_tests.is_datatype, 'exist_mutable_attribute': op_gen_tests.exist_mutable_attribute, 'mutable_attribute': op_gen_tests.is_mutable_attribute, + 'only_composite_op': op_gen_tests.is_only_composite_op, } ) for tpl in env.list_templates( @@ -496,6 +278,22 @@ def process_backward_invoke_info(apis): api['invoke']['args'] = ', '.join(args) +def process_optional_output_info(apis): + for api in apis: + if not api['is_fwd']: + continue + inputs_dict = to_named_dict(api['inputs']) + for output in api['outputs']: + if ( + api.get("inplace", None) + and output['name'] in api['inplace'] + and inputs_dict[api['inplace'][output['name']]]['optional'] + ): + output['optional'] = True + else: + output['optional'] = False + + def gen( prim_path: pathlib.Path, fwd_path: pathlib.Path, @@ -544,12 +342,13 @@ def gen( apis = extend_compat_info(apis, compats) apis = apis + get_inplace_api(apis) process_backward_invoke_info(apis) + process_optional_output_info(apis) render( templates_dir, destination_dir, apis=apis, - backend_white_list=BACKENDS, - vjp_white_list=VJPS, + backend_black_list=BACKENDS_BLACK_LIST, + vjp_black_list=VJPS_BLACK_LIST, vjp_comp_white_list=VJP_COMPS, ) diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 index 25443f52fe8af7..863bbb7de633fb 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_backend.h.j2 @@ -15,24 +15,22 @@ namespace primitive { namespace backend { using Tensor = paddle::Tensor; -using Scalar = paddle::experimental::Scalar; +using Scalar = phi::Scalar; using IntArray = paddle::experimental::IntArray; using DataType = phi::DataType; {% for api in apis %} - {%- if api.name in backend_white_list -%} - {% set inplace_map = {} %} - {% if 'inplace' in api and api.inplace != None %} - {% for source, target in api.inplace.items() %} - {% do inplace_map.update({source: target}) %} - {% endfor %} - {% endif %} - {% if api.attrs is exist_mutable_attribute %} -{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, True, True)}}; + {%- if api is only_composite_op -%}{#- render nothing -#} + {%- elif api.name not in backend_black_list -%} + {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd -%} + {% if api.attrs is exist_mutable_attribute %} +{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, True, True)}}; - {% endif %} -{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, inplace_map, False, True)}}; + {% endif %} +{{common.sig(api.name, api.inputs, api.outputs|trip_intermediate , api.attrs, False, True)}}; + {% endif %} + {% else %}{#- render nothing -#} {% endif %} {% endfor %} } // namespace backend diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 index 34e427f0c2e03b..3b9a94993eaa4e 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_eager_backend.cc.j2 @@ -16,9 +16,9 @@ namespace backend { {{common.sequence('', '', ', ', attrs)}} {%- endmacro -%} -{%- macro sig(name, inputs, attrs, outputs, inplace_map) -%} +{%- macro sig(name, inputs, attrs, outputs) -%} template <> -{{common.ret(outputs, inplace_map)}} {{name}}({{common.params(inputs, attrs, False)}}) +{{common.ret(outputs)}} {{name}}({{common.params(inputs, attrs, False)}}) {%- endmacro -%} {% macro body(name, inputs, attrs, outputs) %} @@ -27,21 +27,15 @@ template <> {%- set attr_names = [] -%} {%- for i in attrs -%} {%- do attr_names.append(i.name) -%} {%-endfor-%} {% filter indent(2, True) %} -VLOG(4) << "Eager Prim API {name}_ad_func call"; +VLOG(4) << "Eager Prim API {{name}}_ad_func call"; return ::{{name}}_ad_func({{common.args(input_names, attr_names)}}); {% endfilter %} {% endmacro %} {% for api in apis %} - {%- if api.is_prim and api.name in backend_white_list -%} - {% set inplace_map = {} %} - {% if 'inplace' in api and api.inplace != None %} - {% for source, target in api.inplace.items() %} - {% do inplace_map.update({source: target}) %} - {% endfor %} - {% endif %} -{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate, inplace_map)}} { + {%- if api.is_prim and api.name not in backend_black_list and api.name[-1] != '_' -%} +{{sig(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}} { {{body(api.name, api.inputs, api.attrs, api.outputs | trip_intermediate)}} } diff --git a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 index 152cd241ad8333..97b150b0d2dfcc 100644 --- a/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/backend/generated/generated_static_backend.cc.j2 @@ -12,9 +12,9 @@ namespace backend { using LazyTensor = paddle::primitive::LazyTensor; -{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False) -%} +{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False) -%} template <> -{{common.ret(outputs, inplace_map)}} {{name}}({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}}) +{{common.ret(outputs)}} {{name}}({{common.params(inputs, attrs, mutable_attribute_as_inputs, False)}}) {%- endmacro -%} {%- macro prepare_ir_api_inputs(inputs)-%} @@ -48,13 +48,13 @@ if({{input.name}}) { {%- macro get_static_backend_outputs(outputs)-%} {%- if outputs|length == 1 -%} - {%- if outputs[0].typename == 'Tensor' and not outputs[0].optional-%} + {%- if outputs[0].typename == 'Tensor' and not outputs[0].optional -%} Tensor {{outputs[0].name}}(std::make_shared(op_res)); return {{outputs[0].name}}; {%- elif outputs[0].typename == 'Tensor' and outputs[0].optional -%} paddle::optional {{outputs[0].name}}; if(op_res){ - {{outputs[0].name}} = paddle::make_optional(Tensor(std::make_shared(op_res.get())); + {{outputs[0].name}} = paddle::make_optional(Tensor(std::make_shared(op_res.get()))); } return {{outputs[0].name}}; {%- elif outputs[0].typename == 'Tensor[]' and not outputs[0].optional -%} @@ -80,7 +80,7 @@ return {{outputs[0].name}}; auto op_res_{{i}} = std::get<{{i}}>(op_res); {% if outputs[i].typename == 'Tensor' and not outputs[i].optional %} Tensor {{outputs[i].name}}(std::make_shared(op_res_{{i}})); - {% elif outputs[i].typename == 'Tensor' and outputs[i].optional %} + {% elif outputs[i].typename == 'Tensor' and outputs[i].optional %} paddle::optional {{outputs[i].name}}; if(op_res_{{i}}){ {{outputs[i].name}} = paddle::make_optional(Tensor(std::make_shared(op_res_{{i}}.get()))); @@ -139,28 +139,26 @@ auto op_res = paddle::dialect::{{name}}({{common.args(input_names, attr_names)}} {% for api in apis %} -{% if api.name in backend_white_list %} +{%- if api is only_composite_op -%}{#- render nothing -#} +{% elif api.name not in backend_black_list %} + {%- if 'invoke' not in api or 'invoke' in api and api.is_fwd-%} {% set api_outputs = api.outputs | trip_intermediate %} - {% set inplace_map = {} %} - {% if 'inplace' in api and api.inplace != None %} - {% for source, target in api.inplace.items() %} - {% do inplace_map.update({source: target}) %} - {% endfor %} - {% endif %} -{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map)}} { +{{sig(api.name, api.inputs, api_outputs, api.attrs)}} { {% filter indent(2, True) %} {{body(api.name, api.inputs, api_outputs, api.attrs)}} {% endfilter %} } - {% if api.attrs is exist_mutable_attribute %} -{{sig(api.name, api.inputs, api_outputs, api.attrs, inplace_map, True)}} { + {% if api.attrs is exist_mutable_attribute %} +{{sig(api.name, api.inputs, api_outputs, api.attrs, True)}} { {% filter indent(2, True) %} {{body(api.name, api.inputs, api_outputs, api.attrs, True)}} {% endfilter %} } + {% endif %} {% endif %} +{% else %}{#- render nothing -#} {% endif %} {% endfor %} diff --git a/paddle/fluid/primitive/codegen/templates/common.j2 b/paddle/fluid/primitive/codegen/templates/common.j2 index 6ac639e8ceeaef..5f7148017ab23b 100644 --- a/paddle/fluid/primitive/codegen/templates/common.j2 +++ b/paddle/fluid/primitive/codegen/templates/common.j2 @@ -1,6 +1,6 @@ -{%- macro sig(name, inputs, outputs, attrs, inplace_map, mutable_attribute_as_inputs=False, default=False) -%} +{%- macro sig(name, inputs, outputs, attrs, mutable_attribute_as_inputs=False, default=False) -%} template -{{ret(outputs, inplace_map)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}}) +{{ret(outputs)}} {{name}}({{params(inputs, attrs, mutable_attribute_as_inputs, default)}}) {%- endmacro %} @@ -40,9 +40,9 @@ template {%- endmacro -%} -{%- macro ret(outputs, inplace_map) -%} +{%- macro ret(outputs) -%} {%- set names = [] -%} - {%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type(i.name in inplace_map and i.optional)) -%} {%- endfor -%} + {%- for i in outputs -%} {%- do names.append(i.typename|to_paddle_output_type(i.optional)) -%} {%- endfor -%} {%- if names|length > 1 -%} std::tuple<{{sequence('', '', ', ', names)}}> {%- else -%} @@ -73,5 +73,9 @@ std::tuple<{{sequence('', '', ', ', names)}}> {%- macro scalar2ir(name, data_type) -%} + {%- if data_type == 'std::vector' -%} +{{name}} + {%- else -%} {{name}}.to<{{data_type}}>() + {%- endif -%} {%- endmacro -%} diff --git a/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2 b/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2 index 5cf6807470f2bf..90c8d4ce5d89fa 100644 --- a/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/primitive/primitive.h.j2 @@ -13,18 +13,12 @@ using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArray; {% for api in apis %} -{%- if api.is_prim and api.name in backend_white_list and api.name[-1] != '_' -%} +{%- if api.is_prim and api.name not in backend_black_list and api.name[-1] != '_' -%} {%- set input_names = [] -%} {%- for i in api.inputs -%} {%- do input_names.append(i.name) -%} {%- endfor -%} {%- set attr_names = [] -%} {%- for i in api.attrs -%} {%- do attr_names.append(i.name) -%} {% endfor %} - {% set inplace_map = {} %} - {% if 'inplace' in api and api.inplace != None %} - {% for source, target in api.inplace.items() %} - {% do inplace_map.update({source: target}) %} - {% endfor %} - {% endif %} -{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, inplace_map, False, True)}} { +{{common.sig(api.name, api.inputs, api.outputs | trip_intermediate, api.attrs, False, True)}} { return backend::{{api.name}}({{common.args(input_names, attr_names)}}); } diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index 50a0c5d86fc318..02e6c58f97af63 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -120,8 +120,10 @@ details::{{api.composite.func_name}}({{api.composite.func_args}}); {%- set api_map = {} -%} {%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%} {%- for api in apis %} - {%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%} + {%- if api.backward and api.backward in api_map and api.backward not in vjp_black_list -%} {%- set backward_api = api_map[api.backward] %} + {%- if backward_api is only_composite_op -%}{#- render nothing -#} + {%- else -%} {{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} { {% filter indent(2, True) %} {{body(backward_api)}} @@ -129,6 +131,7 @@ details::{{api.composite.func_name}}({{api.composite.func_args}}); } {% endif %} + {% endif %} {% endfor %} diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 index 7f403661fea05e..a4209fb5e81748 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.h.j2 @@ -20,11 +20,14 @@ std::vector> {{fwd_name}}_vjp({{common.params(inputs {%- set api_map = {} -%} {%- for api in apis -%} {%- do api_map.update({api.name: api}) -%} {%- endfor -%} {% for api in apis %} - {%- if api.backward and api.backward in api_map and api.backward in vjp_white_list -%} + {%- if api.backward and api.backward in api_map and api.backward not in vjp_black_list -%} {%- set backward_api = api_map[api.backward] -%} + {%- if backward_api is only_composite_op -%}{#- render nothing -#} + {%- else -%} {{sig(api.name, backward_api.name, backward_api.inputs, backward_api.attrs, backward_api.outputs)}} {% endif %} + {% endif %} {% endfor %} } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 7ac642573ca798..e0da626ef4c938 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -14,11 +14,55 @@ #pragma once -namespace paddle { +#include "paddle/fluid/primitive/primitive/primitive.h" +#include "paddle/fluid/primitive/type/lazy_tensor.h" +#include "paddle/fluid/primitive/utils/utils.h" +namespace paddle { namespace primitive { +namespace details { + +template +Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) { + auto org_dtype = x.dtype(); + auto x_tmp = x; + bool need_cast = org_dtype == phi::DataType::FLOAT16 || + org_dtype == phi::DataType::BFLOAT16; + if (need_cast) { + x_tmp = cast(x, phi::DataType::FLOAT32); + } + std::vector x_dim = phi::vectorize(x_tmp.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + auto axis_ = std::vector(); + if (axis_size == 0) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + + int64_t value = 1; + for (size_t i = 0; i < axis_.size(); i++) { + value *= x_dim[axis_[i]]; + } + auto sum_x = sum(x_tmp, IntArray(axis_), x_tmp.dtype(), keepdim); + auto res = divide( + sum_x, full(phi::vectorize(sum_x.dims()), value, sum_x.dtype())); + if (need_cast) { + return cast(res, org_dtype); + } else { + return res; + } +} -namespace experimental {} +} // namespace details } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index ccf9673bafba07..85ffc28a20d20e 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -50,3 +50,5 @@ - tanh - full - cast +- sign +- slice diff --git a/paddle/fluid/primitive/rule/vjp/details.h b/paddle/fluid/primitive/rule/vjp/details.h index 4e2d7d4732b89a..5e8863027a78d1 100644 --- a/paddle/fluid/primitive/rule/vjp/details.h +++ b/paddle/fluid/primitive/rule/vjp/details.h @@ -536,6 +536,304 @@ void dropout_grad(const Tensor& mask, } } +template +void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + auto m_2_sqrt_pi = full(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype()); + auto neg_one = full(phi::vectorize(x.dims()), -1.0, x.dtype()); + auto neg_tmp = neg_one * x * x; + auto mul_tmp = m_2_sqrt_pi * exp(neg_tmp); + set_output(out_grad * mul_tmp, x_grad); + } +} + +template +void expand_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& shape, + Tensor* x_grad) { + if (x_grad) { + auto out_dims = phi::make_ddim(shape.GetData()); + if (out_dims != x.dims()) { + auto axes = get_reduce_dims(x.dims(), out_dims); + if (!axes.size()) { + by_pass(out_grad, x_grad); + } else { + auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false); + if (reduced.dims().size() != x.dims().size()) { + reduced = reshape(reduced, x.shape()); + } + set_output(reduced, x_grad); + } + } else { + by_pass(out_grad, x_grad); + } + } +} + +template +void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + // dx = dout / x + set_output(out_grad / x, x_grad); + } +} + +template +void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + if (out.dtype() == phi::DataType::FLOAT16 || + out.dtype() == phi::DataType::BFLOAT16) { + Tensor out_promote = cast(out, phi::DataType::FLOAT32); + Tensor out_grad_promote = cast(out_grad, phi::DataType::FLOAT32); + set_output(cast(out_promote * out_grad_promote, out.dtype()), + x_grad); + } else { + set_output(out_grad * out, x_grad); + } + } +} + +template +void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + // This calculation is important for resnet. + auto x_grad_tmp = (0.5 / out) * out_grad; + set_output(x_grad_tmp, x_grad); + } +} + +template +void silu_grad(const Tensor& x, + const Tensor& out, + const Tensor& out_grad, + Tensor* x_grad) { + if (x_grad) { + auto org_dtype = x.dtype(); + bool need_cast = org_dtype == phi::DataType::FLOAT16 || + org_dtype == phi::DataType::BFLOAT16; + if (need_cast) { + auto x_cast = cast(x, phi::DataType::FLOAT32); + auto out_cast = cast(out, phi::DataType::FLOAT32); + auto out_grad_cast = cast(out_grad, phi::DataType::FLOAT32); + auto sigmoid = 1.0 / (1.0 + exp(-x_cast)); + auto res = out_grad_cast * sigmoid * (1.0 + x_cast - out_cast); + set_output(cast(res, org_dtype), x_grad); + } else { + auto sigmoid = 1.0 / (1.0 + exp(-x)); + auto res = out_grad * sigmoid * (1.0 + x - out); + set_output(res, x_grad); + } + } +} + +template +void softmax_grad(const Tensor& out, + const Tensor& out_grad, + int axis, + Tensor* x_grad) { + if (x_grad) { + if (out_grad.dims().size() > 0) { + if (axis >= 0) { + auto new_out_grad = out_grad * out; + auto tmp_x_grad = new_out_grad - + out * sum(new_out_grad, {axis}, out.dtype(), true); + set_output(tmp_x_grad, x_grad); + } else { + auto new_out_grad = out_grad * out; + auto tmp_x_grad = + new_out_grad - out * sum(new_out_grad, + {out.dims().size() + axis}, + out.dtype(), + true); + set_output(tmp_x_grad, x_grad); + } + } else { + set_output(out_grad * 0.0, x_grad); + } + } +} + +template +void gather_nd_grad(const Tensor& x, + const Tensor& index, + const Tensor& out_grad, + Tensor* x_grad) { + if (x_grad) { + auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); + auto x_grad_tmp = scatter_nd_add(zero_tensor, index, out_grad); + set_output(x_grad_tmp, x_grad); + } +} + +template +void pad_grad(const Tensor& input, + const Tensor& out_grad, + const std::vector& paddings, + const Scalar& pad_value, + Tensor* input_grad) { + if (input_grad) { + size_t rank = input.dims().size(); + auto out_dims = out_grad.dims(); + + std::vector starts(rank, 0); + std::vector ends(rank, 0); + std::vector axes(rank, 0); + std::vector infer_flags(rank, 1); + std::vector decrease_axis({}); + for (size_t i = 0; i < rank; ++i) { + starts[i] = static_cast(paddings[2 * i]); + ends[i] = static_cast(out_dims[i] - paddings[2 * i + 1]); + axes[i] = i; + } + auto out_tmp = + slice(out_grad, axes, starts, ends, infer_flags, decrease_axis); + set_output(out_tmp, input_grad); + } +} + +template +void max_grad(const Tensor& x, + const Tensor& out, + const Tensor& out_grad, + const IntArray& axis, + bool keepdim, + bool reduce_all, + Tensor* x_grad) { + if (!x_grad) { + return; + } + auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); + std::vector x_dim = phi::vectorize(x.dims()); + int64_t axis_size = axis.size(); + int64_t x_dim_size = x_dim.size(); + reduce_all = false; + if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { + reduce_all = true; + } else { + reduce_all = false; + } + auto x_grad_tmp = Tensor(); + if (x_dim_size == 0 || x_dim_size == 1 || keepdim) { + auto out_grad_tmp = out_grad.expand(IntArray(x_dim)); + auto out_tmp = out.expand(IntArray(x_dim)); + auto mask = equal(x, out_tmp); + x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); + } else { + auto axis_ = std::vector(); + if (reduce_all) { + for (int64_t i = 0; i < x_dim_size; i++) { + axis_.push_back(i); + } + } else { + axis_ = axis.GetData(); + for (int64_t i = 0; i < axis_size; i++) { + if (axis[i] < 0) { + axis_[i] = axis[i] + x_dim_size; + } + } + } + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_ = reshape(out_grad, out_grad_shape); + auto out_ = reshape(out, out_grad_shape); + auto out_grad_tmp = out_grad_.expand(IntArray(x_dim)); + auto out_tmp = out_.expand(IntArray(x_dim)); + auto mask = equal(x, out_tmp); + x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); + } + set_output(x_grad_tmp, x_grad); +} + +template +void slice_grad(const Tensor& input, + const Tensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis, + Tensor* input_grad) { + if (input_grad) { + size_t rank = input.dims().size(); + auto out_dims = out_grad.dims(); + std::vector origin_out_shape; + auto in_dims = input.dims(); + + auto decrease_size = decrease_axis.size(); + if (decrease_size > 0) { + if (decrease_size == static_cast(in_dims.size())) { + // all dims decrease + out_dims = phi::make_ddim(std::vector(decrease_size, 1)); + } else { + origin_out_shape.resize(out_dims.size() + decrease_size, -1); + for (size_t i = 0; i < decrease_size; ++i) { + origin_out_shape[decrease_axis[i]] = 1; + } + + int index = 0; + for (size_t i = 0; i < origin_out_shape.size(); ++i) { + if (origin_out_shape[i] == -1) { + origin_out_shape[i] = out_dims[index]; + ++index; + } + } + out_dims = phi::make_ddim(origin_out_shape); + } + } + + std::vector offsets(rank, 0); + std::vector extents(rank, 0); + for (size_t i = 0; i < rank; ++i) { + offsets[i] = 0; + extents[i] = out_dims[i]; + } + for (size_t i = 0; i < axes.size(); ++i) { + int axis = axes[i]; + int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i]; + start = std::max(start, static_cast(0)); + offsets[axis] = start; + } + + std::vector paddings; + for (size_t i = 0; i < rank; ++i) { + paddings.push_back(offsets[i]); + paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]); + } + if (decrease_size > 0 && + (decrease_size != static_cast(in_dims.size()))) { + auto out_tmp = + pad(reshape(out_grad, origin_out_shape), paddings, 0.0); + set_output(out_tmp, input_grad); + } else { + auto out_tmp = pad(out_grad, paddings, 0.0); + set_output(out_tmp, input_grad); + } + } +} + +template +void tile_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& repeat_times, + Tensor* x_grad) { + if (x_grad) { + auto repeat_times_data = repeat_times.GetData(); + auto out_grad_shape = phi::vectorize(out_grad.dims()); + auto result = out_grad; + for (int i = 0; i < static_cast(repeat_times_data.size()); i++) { + int size = out_grad_shape[i] / repeat_times_data[i]; + std::vector sections(repeat_times_data[i], size); + auto split_arr = split(result, IntArray(sections), i); + result = full(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype()); + for (int j = 0; j < static_cast(split_arr.size()); j++) { + result = split_arr[j] + result; + } + } + result = reshape(result, x.shape()); + set_output(result, x_grad); + } +} + } // namespace details } // namespace primitive } // namespace paddle diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc index 838b83d5d533b5..6b3b1050448ef7 100644 --- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc +++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc @@ -55,6 +55,7 @@ std::vector> reshape_vjp( if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) { FLAGS_tensor_operants_mode = "static"; + VLOG(4) << "Call PIR Decomposed backward op reshape_grad"; paddle::Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr; details::reshape_grad(xshape, out_grad, x_grad); diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 62b595a13f9602..09d76e33d69c1e 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -645,16 +645,20 @@ static void parse_attrs(PyObject *obj, phi::distributed::InferSpmdContext *ctx, const size_t arg_pos) { if (PyBool_Check(first_item)) { - auto attrs = CastPyArg2Booleans(obj, infer_spmd_string, arg_pos); + auto attrs = CastPyArg2Booleans( + obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attrs); } else if (PyCheckInteger(first_item)) { - auto attrs = CastPyArg2Ints(obj, infer_spmd_string, arg_pos); + auto attrs = + CastPyArg2Ints(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attrs); } else if (PyLong_Check(first_item)) { - auto attrs = CastPyArg2Longs(obj, infer_spmd_string, arg_pos); + auto attrs = + CastPyArg2Longs(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attrs); } else if (PyFloat_Check(first_item)) { - auto attrs = CastPyArg2Floats(obj, infer_spmd_string, arg_pos); + auto attrs = + CastPyArg2Floats(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attrs); } else { PADDLE_THROW(platform::errors::InvalidArgument( @@ -671,16 +675,20 @@ static void parse_attr(PyObject *obj, phi::distributed::InferSpmdContext *ctx, const size_t arg_pos) { if (PyBool_Check(obj)) { - auto attr = CastPyArg2Boolean(obj, infer_spmd_string, arg_pos); + auto attr = CastPyArg2Boolean( + obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attr); } else if (PyCheckInteger(obj)) { - auto attr = CastPyArg2Int(obj, infer_spmd_string, arg_pos); + auto attr = + CastPyArg2Int(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attr); } else if (PyLong_Check(obj)) { - auto attr = CastPyArg2Long(obj, infer_spmd_string, arg_pos); + auto attr = + CastPyArg2Long(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attr); } else if (PyFloat_Check(obj)) { - auto attr = CastPyArg2Float(obj, infer_spmd_string, arg_pos); + auto attr = + CastPyArg2Float(obj, infer_spmd_string, static_cast(arg_pos)); ctx->EmplaceBackAttr(attr); } else { // TODO(ljz) support other types PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index 69b2e4b3c9786d..36beb74c7c05bd 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -155,7 +155,7 @@ py::dtype DistModelTypeToNumpyDType(DistModelDataType dtype) { py::array DistModelTensorGetData(DistModelTensor& tensor) { // NOLINT py::dtype dt = DistModelTypeToNumpyDType(tensor.dtype); - return py::array(std::move(dt), {tensor.shape}, tensor.data.data()); + return py::array(dt, {tensor.shape}, tensor.data.data()); } void BindFleetExecutor(py::module* m) { diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc index 9291338d70b656..bc18f368234c54 100644 --- a/paddle/fluid/pybind/data_set_py.cc +++ b/paddle/fluid/pybind/data_set_py.cc @@ -146,7 +146,7 @@ class IterableDatasetWrapper { if (tensors_[i][j]->place() == places_[read_num]) { result[read_num].emplace(slots_[j], std::move(*tensors_[i][j])); } else { - framework::TensorCopy(std::move(*tensors_[i][j]), + framework::TensorCopy(*tensors_[i][j], places_[read_num], &result[read_num][slots_[j]]); } diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 716d207d3b1960..a30f01084a060f 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -189,8 +189,8 @@ void CreateDistTensorWithNumpyValue(TensorObject* self, "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/CustomPlace")); } - auto dist_tensor = - std::make_shared(dense_tensor, dist_attr); + auto dist_tensor = std::make_shared( + std::make_shared(dense_tensor), dist_attr); self->tensor.set_impl(dist_tensor); if (!autograd_meta->GetMutableGradNode()) { @@ -280,13 +280,13 @@ void InitDistTensorWithTensor(TensorObject* self, if (place == src.place()) { std::shared_ptr tensor = std::static_pointer_cast(src.impl()); - self->tensor.set_impl(std::make_shared(*tensor, dist_attr)); + self->tensor.set_impl(std::make_shared(tensor, dist_attr)); VLOG(4) << "Same place, do ShareDataWith for DistTensor."; } else { std::shared_ptr tensor = std::static_pointer_cast( src.copy_to(place, true).impl()); - self->tensor.set_impl(std::make_shared(*tensor, dist_attr)); + self->tensor.set_impl(std::make_shared(tensor, dist_attr)); VLOG(4) << "Different place, do TensorCopy for DistTensor."; } if (src.get_autograd_meta()) { diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 1e1f40bf8e3d41..df3e62b3bae476 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -660,7 +660,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, VLOG(7) << "Custom operator add output " << output << " to CustomOpKernelContext. Add vector size = " << empty_tensors.size(); - ctx.EmplaceBackOutputs(std::move(empty_tensors)); + ctx.EmplaceBackOutputs(empty_tensors); continue; } } diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index e72f5dc77f99cb..199d05d2c98007 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -135,15 +135,15 @@ Returns a numpy array shows the value of current Tensor. same as current Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - data = paddle.uniform([30, 10, 32], dtype="float32", min=-1, max=1) - linear = paddle.nn.Linear(32, 64) - data = paddle.to_tensor(data) - x = linear(data) - print(x.numpy()) + >>> data = paddle.uniform([30, 10, 32], dtype="float32", min=-1, max=1) + >>> linear = paddle.nn.Linear(32, 64) + >>> data = paddle.to_tensor(data) + >>> x = linear(data) )DOC"); static PyObject* tensor_method_numpy(TensorObject* self, @@ -629,16 +629,17 @@ Reconstruct the self with other Tensor. It is a deep copy of 'self = other'. None. Examples: - .. code-block:: python - import paddle + .. code-block:: python - t1 = paddle.to_tensor([1.0], stop_gradient=False) - t2 = paddle.to_tensor([2.0], stop_gradient=True) + >>> import paddle - t1.reconstruct_from_(t2) + >>> t1 = paddle.to_tensor([1.0], stop_gradient=False) + >>> t2 = paddle.to_tensor([2.0], stop_gradient=True) - print(t1) + >>> t1.reconstruct_from_(t2) + >>> print(t1) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, [2.]) )DOC"); static PyObject* tensor_method_reconstruct_from_(TensorObject* self, @@ -706,28 +707,38 @@ Tn addition, the cloned Tensor provides gradient propagation. Tensor, The cloned Tensor. Examples: + .. code-block:: python - import paddle - - x = paddle.to_tensor(1.0, stop_gradient=False) - clone_x = x.clone() - y = clone_x**2 - y.backward() - print(clone_x.stop_gradient) # False - print(clone_x.grad) # [2.0], support gradient propagation - print(x.stop_gradient) # False - print(x.grad) # [2.0], clone_x support gradient propagation for x - - x = paddle.to_tensor(1.0) - clone_x = x.clone() - clone_x.stop_gradient = False - z = clone_x**3 - z.backward() - print(clone_x.stop_gradient) # False - print(clone_x.grad) # [3.0], support gradient propagation - print(x.stop_gradient) # True - print(x.grad) # None + >>> import paddle + + >>> x = paddle.to_tensor(1.0, stop_gradient=False) + >>> clone_x = x.clone() + >>> clone_x.retain_grads() + >>> y = clone_x**2 + >>> y.backward() + >>> print(clone_x.stop_gradient) + False + >>> print(clone_x.grad) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=False, 2.) + >>> print(x.stop_gradient) + False + >>> print(x.grad) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=False, 2.) + + >>> x = paddle.to_tensor(1.0) + >>> clone_x = x.clone() + >>> clone_x.stop_gradient = False + >>> z = clone_x**3 + >>> z.backward() + >>> print(clone_x.stop_gradient) + False + >>> print(clone_x.grad) + Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=False, 3.) + >>> print(x.stop_gradient) + True + >>> print(x.grad) + None )DOC"); static PyObject* tensor_method_clone(TensorObject* self, @@ -760,27 +771,32 @@ Enables this Tensor to have their grad populated during backward(). It is a no-o None. Examples: - .. code-block:: python - import paddle - - x = paddle.to_tensor([1.0, 2.0, 3.0]) - x.stop_gradient = False - y = x + x - y.retain_grads() - loss = y.sum() - loss.backward() - - print(y.grad) # [1., 1., 1.] - - x = paddle.to_tensor([1.0, 2.0, 3.0]) - x.stop_gradient = False - y = x + x - # y.retain_grads() - loss = y.sum() - loss.backward() + .. code-block:: python - print(y.grad) # None + >>> import paddle + + >>> x = paddle.to_tensor([1.0, 2.0, 3.0]) + >>> x.stop_gradient = False + >>> y = x + x + >>> y.retain_grads() + >>> loss = y.sum() + >>> loss.backward() + + >>> print(y.grad) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False, + [1., 1., 1.]) + + >>> x = paddle.to_tensor([1.0, 2.0, 3.0]) + >>> x.stop_gradient = False + >>> y = x + x + >>> y.retain_grads() + >>> loss = y.sum() + >>> loss.backward() + + >>> print(y.grad) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False, + [1., 1., 1.]) )DOC"); static PyObject* tensor_retain_grads(TensorObject* self, @@ -820,16 +836,26 @@ The Gradient of current Tensor will be set to ``0`` elementwise or ``None``. None. Examples: + .. code-block:: python - import paddle - input = paddle.uniform([10, 2]) - linear = paddle.nn.Linear(2, 3) - out = linear(input) - out.backward() - print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad)) - linear.weight.clear_gradient() - print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad)) + >>> import paddle + >>> input = paddle.uniform([10, 2]) + >>> linear = paddle.nn.Linear(2, 3) + >>> out = linear(input) + >>> out.backward() + >>> print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad)) + >>> # doctest: +SKIP("Random output") + Before clear_gradient, linear.weight.grad: Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=False, + [[-0.03178465, -0.03178465, -0.03178465], + [-0.98546225, -0.98546225, -0.98546225]]) + >>> # doctest: -SKIP + >>> linear.weight.clear_gradient() + >>> print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad)) + After clear_gradient, linear.weight.grad: Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=False, + [[0., 0., 0.], + [0., 0., 0.]]) + )DOC"); static PyObject* tensor_clear_gradient(TensorObject* self, @@ -844,7 +870,7 @@ static PyObject* tensor_clear_gradient(TensorObject* self, set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0); } - paddle::Tensor* grad; + paddle::Tensor* grad = nullptr; bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor); if (is_leaf) { grad = egr::EagerUtils::mutable_grad(self->tensor); @@ -1037,33 +1063,41 @@ In addition, the detached Tensor doesn't provide gradient propagation. Tensor, The detached Tensor. Examples: + .. code-block:: python - import paddle - - x = paddle.to_tensor([1.0], stop_gradient=False) - detach_x = x.detach() - detach_x[0] = 10.0 - print(x) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False, - # [10.]) - y = x**2 - y.backward() - print(x.grad) # [20.0] - print(detach_x.grad) # None, 'stop_gradient=True' by default - - detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad - z = detach_x**3 - z.backward() - - print(x.grad) # [20.0], detach_x is detached from x's graph, not affect each other - print(detach_x.grad) # [300.0], detach_x has its own graph - - # Due to sharing of data with origin Tensor, There are some unsafe operations: - # y = 2 * x - # detach_x[:] = 5.0 - # y.backward() - # It will raise Error: - # one of the variables needed for gradient computation has been modified by an inplace operation. + >>> import paddle + + >>> x = paddle.to_tensor([1.0], stop_gradient=False) + >>> detach_x = x.detach() + >>> detach_x[0] = 10.0 + >>> print(x) + Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False, [10.]) + + >>> y = x**2 + >>> y.backward() + >>> print(x.grad) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False, [20.]) + + >>> print(detach_x.grad) # None, 'stop_gradient=True' by default + None + + >>> detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad + >>> z = detach_x**3 + >>> z.backward() + + >>> print(x.grad) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False, [20.]) + + >>> print(detach_x.grad) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False, [300.]) + + >>> # Due to sharing of data with origin Tensor, There are some unsafe operations: + >>> # y = 2 * x + >>> # detach_x[:] = 5.0 + >>> # y.backward() + >>> # It will raise Error: + >>> # one of the variables needed for gradient computation has been modified by an inplace operation. )DOC"); static PyObject* tensor_method_detach(TensorObject* self, @@ -1132,13 +1166,19 @@ Returns the underline tensor in the origin Tensor. Underline tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1.0], stop_gradient=False) - underline_x = x.get_tensor() - print(underline_x) # a Dense Tensor info + >>> x = paddle.to_tensor([1.0], stop_gradient=False) + >>> underline_x = x.get_tensor() + >>> print(underline_x) + - place: Place(cpu) + - shape: [1] + - layout: NCHW + - dtype: float32 + - data: [1] )DOC"); static PyObject* tensor_method_get_underline_tensor(TensorObject* self, @@ -1729,7 +1769,7 @@ static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - int64_t hook_id; + int64_t hook_id = 0; if (egr::EagerUtils::IsLeafTensor(self->tensor)) { VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name(); @@ -2022,16 +2062,17 @@ Returns the total number of non zero elements in input SparseCooTensor/SparseCsr int Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.nnz() - # 3 + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.nnz() + 3 )DOC"); @@ -2069,18 +2110,19 @@ Returns the indices of non zero elements in input SparseCooTensor. DenseTesnor Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.indices() - # Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [[0, 1, 2], - # [1, 2, 0]]) + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.indices() + Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[0, 1, 2], + [1, 2, 0]]) )DOC"); @@ -2112,17 +2154,18 @@ Returns the values of non zero elements in input SparseCooTensor. DenseTesnor Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.values() - # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, - # [1., 2., 3.]) + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.values() + Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, + [1., 2., 3.]) )DOC"); @@ -2164,18 +2207,19 @@ Returns the compressed row index of non zero elements in input SparseCsrTensor. DenseTesnor Examples: + .. code-block:: python - import paddle + >>> import paddle - crows = [0, 2, 3, 5] - cols = [1, 3, 2, 0, 1] - values = [1, 2, 3, 4, 5] - dense_shape = [3, 4] - csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) - csr.crows() - # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [0, 2, 3, 5]) + >>> crows = [0, 2, 3, 5] + >>> cols = [1, 3, 2, 0, 1] + >>> values = [1, 2, 3, 4, 5] + >>> dense_shape = [3, 4] + >>> csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + >>> csr.crows() + Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [0, 2, 3, 5]) )DOC"); @@ -2207,18 +2251,19 @@ Returns the column index of non zero elements in input SparseCsrTensor. DenseTesnor Examples: + .. code-block:: python - import paddle + >>> import paddle - crows = [0, 2, 3, 5] - cols = [1, 3, 2, 0, 1] - values = [1, 2, 3, 4, 5] - dense_shape = [3, 4] - csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) - csr.cols() - # Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True, - # [1, 3, 2, 0, 1]) + >>> crows = [0, 2, 3, 5] + >>> cols = [1, 3, 2, 0, 1] + >>> values = [1, 2, 3, 4, 5] + >>> dense_shape = [3, 4] + >>> csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + >>> csr.cols() + Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [1, 3, 2, 0, 1]) )DOC"); @@ -2246,12 +2291,14 @@ Whether the Tensor is a Dense Tensor. Whether the Tensor is a Dense Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1.0], stop_gradient=False) - print(x.is_dense()) + >>> x = paddle.to_tensor([1.0], stop_gradient=False) + >>> print(x.is_dense()) + True )DOC"); static PyObject* tensor_method_is_dense(TensorObject* self, @@ -2274,12 +2321,14 @@ Whether the Tensor is a Distributed Tensor. Whether the Tensor is a Distributed Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1.0], stop_gradient=False) - print(x.is_dist()) # False + >>> x = paddle.to_tensor([1.0], stop_gradient=False) + >>> print(x.is_dist()) + False )DOC"); static PyObject* tensor_method_is_dist(TensorObject* self, @@ -2305,16 +2354,17 @@ When input is SparseCooTensor/SparseCsrTensor, will return True. When input is D bool Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.is_sparse() - # True + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.is_sparse() + True )DOC"); static PyObject* tensor_method_is_sparse(TensorObject* self, @@ -2341,16 +2391,17 @@ When input is SparseCooTensor, will return True. When input is DenseTensor/Spars bool Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.is_sparse_coo() - # True + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.is_sparse_coo() + True )DOC"); @@ -2377,17 +2428,18 @@ When input is SparseCsrTensor, will return True. When input is DenseTensor/Spars bool Examples: + .. code-block:: python - import paddle + >>> import paddle - crows = [0, 2, 3, 5] - cols = [1, 3, 2, 0, 1] - values = [1, 2, 3, 4, 5] - dense_shape = [3, 4] - csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) - csr.is_sparse_csr() - # True + >>> crows = [0, 2, 3, 5] + >>> cols = [1, 3, 2, 0, 1] + >>> values = [1, 2, 3, 4, 5] + >>> dense_shape = [3, 4] + >>> csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape) + >>> csr.is_sparse_csr() + True )DOC"); @@ -2417,19 +2469,20 @@ When input is SparseCooTensor, will convert `COO` to `CSR` . When input is Dense SparseCsrTensor Examples: + .. code-block:: python - import paddle + >>> import paddle - indices = [[0, 1, 2], [1, 2, 0]] - values = [1.0, 2.0, 3.0] - dense_shape = [3, 3] - coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) - coo.to_sparse_csr() - # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 1, 2, 3], - # cols=[1, 2, 0], - # values=[1., 2., 3.]) + >>> indices = [[0, 1, 2], [1, 2, 0]] + >>> values = [1.0, 2.0, 3.0] + >>> dense_shape = [3, 3] + >>> coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape) + >>> coo.to_sparse_csr() + Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True, + crows=[0, 1, 2, 3], + cols=[1, 2, 0], + values=[1., 2., 3.]) )DOC"); @@ -2466,17 +2519,17 @@ Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are suppor .. code-block:: python - import paddle + >>> import paddle - x = paddle.rand([2, 3, 8]) - y = paddle.rand([2, 3, 8]) - y = y.to_sparse_csr() - z = paddle.rand([2, 5]) + >>> x = paddle.rand([2, 3, 8]) + >>> y = paddle.rand([2, 3, 8]) + >>> y = y.to_sparse_csr() + >>> z = paddle.rand([2, 5]) - x.is_same_shape(y) - # True - x.is_same_shape(z) - # False + >>> x.is_same_shape(y) + True + >>> x.is_same_shape(z) + False )DOC"); @@ -2509,24 +2562,30 @@ Returns the size in bytes of an element in the Tensor. int, The size in bytes of an element in the Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor(1, dtype='bool') - x.element_size() # 1 + >>> x = paddle.to_tensor(1, dtype='bool') + >>> x.element_size() + 1 - x = paddle.to_tensor(1, dtype='float16') - x.element_size() # 2 + >>> x = paddle.to_tensor(1, dtype='float16') + >>> x.element_size() + 2 - x = paddle.to_tensor(1, dtype='float32') - x.element_size() # 4 + >>> x = paddle.to_tensor(1, dtype='float32') + >>> x.element_size() + 4 - x = paddle.to_tensor(1, dtype='float64') - x.element_size() # 8 + >>> x = paddle.to_tensor(1, dtype='float64') + >>> x.element_size() + 8 - x = paddle.to_tensor(1, dtype='complex128') - x.element_size() # 16 + >>> x = paddle.to_tensor(1, dtype='complex128') + >>> x.element_size() + 16 )DOC"); static PyObject* tensor_method_element_size(TensorObject* self, @@ -2753,12 +2812,16 @@ Returns the address of the first element of current Tensor. int, The address of the first element of current Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1, 2, 3]) - print(x.data_ptr()) + >>> x = paddle.to_tensor([1, 2, 3]) + >>> print(x.data_ptr()) + >>> # doctest: +SKIP('return the address') + 93220864 + >>> # doctest: -SKIP )DOC"); static PyObject* tensor_data_ptr(TensorObject* self, @@ -2800,13 +2863,15 @@ Returns the strides of current Tensor. List, the strides of current Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1, 2, 3]) - y = x[1] - print(y.get_strides()) + >>> x = paddle.to_tensor([1, 2, 3]) + >>> y = x[1] + >>> print(y.get_strides()) + [] )DOC"); static PyObject* tensor_method_strides(TensorObject* self, @@ -2838,14 +2903,16 @@ If self tensor is already contiguous, this function returns the current Tensor. Tensor, The contiguous Tensor. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1, 2, 3]) - y = x[1] - y = y.contiguous() - print(y) + >>> x = paddle.to_tensor([1, 2, 3]) + >>> y = x[1] + >>> y = y.contiguous() + >>> print(y) + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 2) )DOC"); static PyObject* tensor_contiguous(TensorObject* self, @@ -2883,13 +2950,14 @@ Whether the Tensor is contiguous. Bool, Whether the Tensor is contiguous. Examples: + .. code-block:: python - import paddle + >>> import paddle - x = paddle.to_tensor([1, 2, 3]) - y = x[1] - print(y.is_contiguous()) + >>> x = paddle.to_tensor([1, 2, 3]) + >>> y = x[1] + >>> print(y.is_contiguous()) )DOC"); static PyObject* tensor_is_contiguous(TensorObject* self, PyObject* args, diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 7d70ed174a4c81..46170298ce42b6 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -237,8 +237,8 @@ float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos) { std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos) { if (PyObject_CheckStr(obj)) { - Py_ssize_t size; - const char* data; + Py_ssize_t size = 0; + const char* data = nullptr; data = PyUnicode_AsUTF8AndSize(obj, &size); return std::string(data, static_cast(size)); } else { @@ -1842,7 +1842,7 @@ paddle::Tensor PyTensorHook::operator()(const paddle::Tensor& var) { res = PyObject_CallFunctionObjArgs(py_func_, p_tmp_var, nullptr); Py_DECREF(p_tmp_var); } catch (platform::EnforceNotMet& e) { - throw std::move(e); + throw e; } catch (std::exception& e) { PADDLE_THROW(platform::errors::Unavailable( "Hook function of Tensor raises an exception: %s.", e.what())); @@ -1869,7 +1869,7 @@ void PyVoidHook::operator()() { try { PyObject_CallFunctionObjArgs(py_func_, nullptr); } catch (platform::EnforceNotMet& e) { - throw std::move(e); + throw e; } catch (std::exception& e) { PADDLE_THROW(platform::errors::Unavailable( "Hook function of Tensor raises an exception: %s.", e.what())); @@ -2079,9 +2079,9 @@ void DistTensorConverter::convert(Tensor* x) { phi::distributed::TensorDistAttr dist_attr( phi::vectorize(x->impl()->dims())); dist_attr.set_process_mesh(*mesh); - auto dense_t = static_cast(x->impl().get()); + auto dense_t = std::static_pointer_cast(x->impl()); x->set_impl( - std::make_shared(*dense_t, dist_attr)); + std::make_shared(dense_t, dist_attr)); } } diff --git a/paddle/fluid/pybind/eval_frame.c b/paddle/fluid/pybind/eval_frame.c index a07f2033e4b4e8..5b4f216be24dc7 100644 --- a/paddle/fluid/pybind/eval_frame.c +++ b/paddle/fluid/pybind/eval_frame.c @@ -184,8 +184,13 @@ int Internal_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) { if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) { /* Free vars have not been initialized -- Do that */ PyCodeObject *co = frame->f_code; +#if PY_VERSION_HEX >= 0x030c0000 + PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure; + int offset = co->co_nlocals + co->co_ncellvars; +#else PyObject *closure = frame->f_func->func_closure; int offset = co->co_nlocals + co->co_nplaincellvars; +#endif for (int i = 0; i < co->co_nfreevars; ++i) { PyObject *o = PyTuple_GET_ITEM(closure, i); Py_INCREF(o); @@ -269,6 +274,8 @@ PyFrameObject *Internal_PyFrame_New_NoTrack(PyCodeObject *code) { return f; } +#if PY_VERSION_HEX < 0x030c0000 + PyFrameObject *Internal_PyFrame_MakeAndSetFrameObject( _PyInterpreterFrame *frame) { assert(frame->frame_obj == NULL); @@ -387,6 +394,8 @@ void Internal_PyFrame_Clear(_PyInterpreterFrame *frame) { Py_DECREF(frame->f_code); } +#endif + #else typedef PyFrameObject FrameObject; #endif @@ -449,9 +458,11 @@ inline static PyObject *eval_custom_code_py311_plus(PyThreadState *tstate, // Create a new function object from code object. Refer to MAKE_FUNCTION. PyFunctionObject *func = (PyFunctionObject *)PyFunction_New((PyObject *)code, frame->f_globals); +#if PY_VERSION_HEX < 0x030c0000 Py_XINCREF(frame->f_func->func_closure); func->func_closure = frame->f_func->func_closure; _PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus); +#endif PyObject **fastlocals_old = frame->localsplus; PyObject **fastlocals_new = shadow->localsplus; @@ -483,7 +494,9 @@ inline static PyObject *eval_custom_code_py311_plus(PyThreadState *tstate, } PyObject *result = eval_frame_default(tstate, shadow, throw_flag); +#if PY_VERSION_HEX < 0x030c0000 Internal_PyFrame_Clear(shadow); +#endif free(shadow); Py_DECREF(func); Py_DECREF(namemap); @@ -558,7 +571,11 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, // original frame. So we pass a PyInterpreterFrame to // _PyFrame_FastToLocalsWithError directly. But this is an internal API, so we // copy many code from CPython project into our project. +#if PY_VERSION_HEX >= 0x030c0000 + if (true) { +#else if (Internal_PyFrame_FastToLocalsWithError(frame) < 0) { +#endif #else if (PyFrame_FastToLocalsWithError(frame) < 0) { #endif @@ -605,7 +622,7 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyCodeObject *code = (PyCodeObject *)PyObject_GetAttrString(result, "code"); PyObject *disable_eval_frame = PyObject_GetAttrString(result, "disable_eval_frame"); - PyObject *out; + PyObject *out = NULL; // VLOG(7) << "Start eval new frame and code."; if (disable_eval_frame != Py_True) { // Re-enable custom behavior diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 80e6de07919611..55efda46c86b07 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -108,7 +108,7 @@ class PyVariableWrapperHook : public imperative::VariableWrapperHook { res = PyObject_CallFunctionObjArgs( py_func_, py::cast(tmp_varbase).ptr(), nullptr); } catch (platform::EnforceNotMet &e) { - throw std::move(e); + throw e; } catch (std::exception &e) { PADDLE_THROW(platform::errors::Unavailable( "Hook function of Tensor raises an exception: %s.", e.what())); diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 4e977b010de0fa..019b5098feb75f 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -206,7 +206,7 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { py::array PaddleTensorGetData(PaddleTensor &tensor) { // NOLINT py::dtype dt = PaddleDTypeToNumpyDType(tensor.dtype); - return py::array(std::move(dt), {tensor.shape}, tensor.data.data()); + return py::array(dt, {tensor.shape}, tensor.data.data()); } template @@ -214,7 +214,7 @@ void ZeroCopyTensorCreate(ZeroCopyTensor &tensor, // NOLINT py::array_t data) { std::vector shape; std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape)); - tensor.Reshape(std::move(shape)); + tensor.Reshape(shape); tensor.copy_from_cpu(static_cast(data.data())); } @@ -235,7 +235,7 @@ void PaddleInferTensorCreate(paddle_infer::Tensor &tensor, // NOLINT py::array_t data) { std::vector shape; std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape)); - tensor.Reshape(std::move(shape)); + tensor.Reshape(shape); tensor.CopyFromCpu(static_cast(data.data())); } @@ -1265,8 +1265,8 @@ void BindPaddlePassBuilder(py::module *m) { .def("set_passes", [](PaddlePassBuilder &self, const std::vector &passes) { self.ClearPasses(); - for (auto pass : passes) { - self.AppendPass(std::move(pass)); + for (auto const &pass : passes) { + self.AppendPass(pass); } }) .def("append_pass", &PaddlePassBuilder::AppendPass) @@ -1318,6 +1318,11 @@ void BindInternalUtils(py::module *m) { .def_static("set_transformer_maskid", [](paddle_infer::Config &config, std::string tensor_name) { InternalUtils::SetTransformerMaskid(&config, tensor_name); + }) + .def_static("disable_tensorrt_half_ops", + [](paddle_infer::Config &config, + const std::unordered_set &ops) { + InternalUtils::DisableTensorRtHalfOps(&config, ops); }); } } // namespace diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index d28dc9bec40088..489b25f35867c8 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -289,8 +289,8 @@ std::string CastPyArg2String(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { if (PyObject_CheckString(obj)) { - Py_ssize_t size; - const char* data; + Py_ssize_t size = 0; + const char* data = nullptr; data = PyUnicode_AsUTF8AndSize(obj, &size); return std::string(data, (size_t)size); // NOLINT } else { @@ -696,8 +696,8 @@ std::vector CastPyArg2Strings(PyObject* obj, for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckString(item)) { - Py_ssize_t size; - const char* data; + Py_ssize_t size = 0; + const char* data = nullptr; data = PyUnicode_AsUTF8AndSize(item, &size); value.emplace_back(std::string(data, (size_t)size)); // NOLINT } else { @@ -716,8 +716,8 @@ std::vector CastPyArg2Strings(PyObject* obj, for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckString(item)) { - Py_ssize_t size; - const char* data; + Py_ssize_t size = 0; + const char* data = nullptr; data = PyUnicode_AsUTF8AndSize(item, &size); value.emplace_back(std::string(data, (size_t)size)); // NOLINT } else { @@ -896,8 +896,8 @@ void ConstructAttrMapFromPyArgs( PyObject* obj = nullptr; for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { VLOG(1) << "Start Process " << arg_pos; - Py_ssize_t key_len; - const char* key_ptr; + Py_ssize_t key_len = 0; + const char* key_ptr = nullptr; obj = PyTuple_GET_ITEM(args, arg_pos); if (PyObject_CheckString(obj)) { key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len); @@ -988,8 +988,8 @@ void ConstructAttrMapForRunProgram( PyObject* obj = nullptr; for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) { VLOG(1) << "Start Process " << arg_pos; - Py_ssize_t key_len; - const char* key_ptr; + Py_ssize_t key_len = 0; + const char* key_ptr = nullptr; obj = PyTuple_GET_ITEM(args, arg_pos); if (PyObject_CheckString(obj)) { key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len); diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc index 9ba115381a2c00..5b8d169d91f746 100644 --- a/paddle/fluid/pybind/parallel_executor.cc +++ b/paddle/fluid/pybind/parallel_executor.cc @@ -264,21 +264,21 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - data = static.data(name="x", shape=[None, 1], dtype="float32") - hidden = static.nn.fc(data, size=10) - loss = paddle.mean(hidden) - paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) + >>> data = static.data(name="x", shape=[None, 1], dtype="float32") + >>> hidden = static.nn.fc(data, size=10) + >>> loss = paddle.mean(hidden) + >>> paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) - build_strategy = static.BuildStrategy() - build_strategy.enable_inplace = True - build_strategy.memory_optimize = True - build_strategy.reduce_strategy = static.BuildStrategy.ReduceStrategy.Reduce - program = static.CompiledProgram(static.default_main_program(), build_strategy=build_strategy) + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.enable_inplace = True + >>> build_strategy.memory_optimize = True + >>> build_strategy.reduce_strategy = static.BuildStrategy.ReduceStrategy.Reduce + >>> program = static.CompiledProgram(static.default_main_program(), build_strategy=build_strategy) )DOC"); py::enum_(build_strategy, "ReduceStrategy") @@ -316,14 +316,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.reduce_strategy = static.BuildStrategy.ReduceStrategy.Reduce - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.reduce_strategy = static.BuildStrategy.ReduceStrategy.Reduce + )DOC") .def_property( "gradient_scale_strategy", [](const BuildStrategy &self) { return self.gradient_scale_; }, @@ -345,38 +345,38 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import numpy - import paddle - import paddle.static as static - - paddle.enable_static() - - use_cuda = True - place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() - exe = static.Executor(place) - - data = static.data(name='X', shape=[None, 1], dtype='float32') - hidden = static.nn.fc(data, size=10) - loss = paddle.mean(hidden) - paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) - - exe.run(static.default_startup_program()) - - build_strategy = static.BuildStrategy() - build_strategy.gradient_scale_strategy = \ - static.BuildStrategy.GradientScaleStrategy.Customized - compiled_prog = static.CompiledProgram( - static.default_main_program(), - build_strategy=build_strategy, - ) - - x = numpy.random.random(size=(10, 1)).astype('float32') - loss_grad = numpy.ones((1)).astype("float32") * 0.01 - loss_grad_name = loss.name+"@GRAD" - loss_data = exe.run(compiled_prog, - feed={"X": x, loss_grad_name : loss_grad}, - fetch_list=[loss.name, loss_grad_name]) - )DOC") + >>> import numpy + >>> import paddle + >>> import paddle.static as static + + >>> paddle.enable_static() + + >>> use_cuda = paddle.device.is_compiled_with_cuda + >>> place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + >>> exe = static.Executor(place) + + >>> data = static.data(name='X', shape=[None, 1], dtype='float32') + >>> hidden = static.nn.fc(data, size=10) + >>> loss = paddle.mean(hidden) + >>> paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) + + >>> exe.run(static.default_startup_program()) + + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.gradient_scale_strategy = \ + ... static.BuildStrategy.GradientScaleStrategy.Customized + >>> compiled_prog = static.CompiledProgram( + ... static.default_main_program(), + ... build_strategy=build_strategy, + >>> ) + + >>> x = numpy.random.random(size=(10, 1)).astype('float32') + >>> loss_grad = numpy.ones((1)).astype("float32") * 0.01 + >>> loss_grad_name = loss.name+"@GRAD" + >>> loss_data = exe.run(compiled_prog, + ... feed={"X": x, loss_grad_name : loss_grad}, + ... fetch_list=[loss.name, loss_grad_name]) + )DOC") .def_property( "debug_graphviz_path", [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, @@ -395,14 +395,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.debug_graphviz_path = "./graph" - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.debug_graphviz_path = "./graph" + )DOC") .def_property( "enable_sequential_execution", [](const BuildStrategy &self) { @@ -422,13 +422,13 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.enable_sequential_execution = True + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.enable_sequential_execution = True )DOC") .def_property( "remove_unnecessary_lock", @@ -449,13 +449,13 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.remove_unnecessary_lock = True + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.remove_unnecessary_lock = True )DOC") .def_property( "num_trainers", @@ -525,16 +525,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Default False. Examples: - .. code-block:: python - - import paddle - import paddle.static as static - - paddle.enable_static() + .. code-block:: python - build_strategy = static.BuildStrategy() - build_strategy.build_cinn_pass = True - )DOC") + >>> import paddle + >>> import paddle.static as static + >>> paddle.enable_static() + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.build_cinn_pass = True + )DOC") .def_property( "fuse_elewise_add_act_ops", [](const BuildStrategy &self) { @@ -555,14 +553,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_elewise_add_act_ops = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_elewise_add_act_ops = True + )DOC") .def_property( "fuse_gemm_epilogue", [](const BuildStrategy &self) { return self.fuse_gemm_epilogue_; }, @@ -581,14 +579,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_gemm_epilogue = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_gemm_epilogue = True + )DOC") .def_property( "fuse_adamw", [](const BuildStrategy &self) { return self.fuse_adamw_; }, @@ -605,12 +603,13 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT it may make the execution faster. Default is False. Examples: .. code-block:: python - import paddle - import paddle.static as static - paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_adamw = True - )DOC") + + >>> import paddle + >>> import paddle.static as static + >>> paddle.enable_static() + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_adamw = True + )DOC") .def_property( "fused_attention", [](const BuildStrategy &self) { return self.fused_attention_; }, @@ -629,14 +628,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fused_attention = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fused_attention = True + )DOC") .def_property( "fused_feedforward", [](const BuildStrategy &self) { return self.fused_feedforward_; }, @@ -655,14 +654,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fused_feedforward = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fused_feedforward = True + )DOC") .def_property( "sequential_run", [](const BuildStrategy &self) { return self.sequential_run_; }, @@ -680,14 +679,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.sequential_run = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.sequential_run = True + )DOC") .def_property( "fuse_bn_act_ops", [](const BuildStrategy &self) { return self.fuse_bn_act_ops_; }, @@ -706,14 +705,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_bn_act_ops = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_bn_act_ops = True + )DOC") .def_property( "fuse_bn_add_act_ops", [](const BuildStrategy &self) { return self.fuse_bn_add_act_ops_; }, @@ -732,14 +731,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_bn_add_act_ops = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_bn_add_act_ops = True + )DOC") .def_property( "enable_auto_fusion", [](const BuildStrategy &self) { return self.enable_auto_fusion_; }, @@ -759,14 +758,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.enable_auto_fusion = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.enable_auto_fusion = True + )DOC") .def_property( "fuse_relu_depthwise_conv", [](const BuildStrategy &self) { @@ -789,13 +788,13 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_relu_depthwise_conv = True + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_relu_depthwise_conv = True )DOC") .def_property( "fuse_broadcast_ops", @@ -819,16 +818,15 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT for NCCLReduce operations for a period of time. Default False. Examples: - .. code-block:: python - - import paddle - import paddle.static as static + .. code-block:: python - paddle.enable_static() + >>> import paddle + >>> import paddle.static as static + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.fuse_broadcast_ops = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.fuse_broadcast_ops = True + )DOC") .def_property( "fuse_all_optimizer_ops", [](const BuildStrategy &self) { @@ -864,14 +862,14 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.sync_batch_norm = True - )DOC") + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.sync_batch_norm = True + )DOC") .def_property( "memory_optimize", [](const BuildStrategy &self) -> py::object { @@ -904,15 +902,15 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT Examples: .. code-block:: python - import paddle - import paddle.static as static + >>> import paddle + >>> import paddle.static as static - paddle.enable_static() + >>> paddle.enable_static() - build_strategy = static.BuildStrategy() - build_strategy.memory_optimize = True + >>> build_strategy = static.BuildStrategy() + >>> build_strategy.memory_optimize = True - )DOC") + )DOC") .def_property( "is_distribution", [](const BuildStrategy &self) { return self.is_distribution_; }, diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index f32fe6f592218d..3e50bd64ca4ac0 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -25,6 +25,7 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" @@ -46,6 +47,7 @@ #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/utils/flags.h" #include "pybind11/stl.h" namespace py = pybind11; @@ -66,6 +68,8 @@ using pybind11::return_value_policy; USE_PASS(dead_code_elimination); USE_PASS(inplace); +PHI_DECLARE_bool(print_ir); + namespace paddle { namespace pybind { @@ -340,6 +344,14 @@ void BindOperation(py::module *m) { }); } +py::str Value2String(const Value &self) { + std::ostringstream print_stream; + print_stream << "Value("; + print_stream << GetValueInfo(self); + print_stream << ")"; + return print_stream.str(); +} + void BindValue(py::module *m) { py::class_ value(*m, "Value", R"DOC( Value class represents the SSA value in the IR system. It is a directed edge @@ -363,6 +375,10 @@ void BindValue(py::module *m) { .def("first_use", &Value::first_use, return_value_policy::reference) .def("has_one_use", &Value::HasOneUse) .def("use_empty", &Value::use_empty) + .def("replace_all_uses_with", + [](Value &self, Value &op_value) { + self.ReplaceAllUsesWith(op_value); + }) .def("__eq__", &Value::operator==) .def("__eq__", [](Value &self, OpResult &other) { @@ -370,13 +386,8 @@ void BindValue(py::module *m) { }) .def("__hash__", [](const Value &self) { return std::hash{}(self); }) - .def("__str__", [](const Value &self) -> py::str { - std::ostringstream print_stream; - print_stream << "Value("; - print_stream << GetValueInfo(self); - print_stream << ")"; - return print_stream.str(); - }); + .def("__str__", &Value2String) + .def("__repr__", &Value2String); } void BindOpOperand(py::module *m) { @@ -457,6 +468,16 @@ phi::DataType GetOpResultDtype(const OpResult &result) { } } +const phi::DDim &GetOpResultDims(const OpResult &result) { + if (result.type().isa()) { + return result.type().dyn_cast().dims(); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, we can only get shape for dense " + "tensor.")); + } +} + #define OVERRIDE_OPERATOR(operator, api, other_type) \ op_result.def(#operator, [](OpResult &self, other_type other) { \ return paddle::dialect::api(self, other); \ @@ -610,6 +631,12 @@ void BindOpResult(py::module *m) { return false; } }) + .def("numel", + [](OpResult &self) { return phi::product(GetOpResultDims(self)); }) + .def("replace_all_uses_with", + [](OpResult &self, OpResult &op_result) { + self.ReplaceAllUsesWith(op_result); + }) .def_property( "stop_gradient", [](OpResult &self) { @@ -638,16 +665,7 @@ void BindOpResult(py::module *m) { }) .def_property( "shape", - [](OpResult &self) { - if (self.type().isa()) { - return phi::vectorize( - self.type().dyn_cast().dims()); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Currently, we can only get shape for dense " - "tensor.")); - } - }, + [](OpResult &self) { return phi::vectorize(GetOpResultDims(self)); }, [](OpResult &self, const std::vector &shape) { PADDLE_THROW(phi::errors::InvalidArgument( "can't set shape when building static graph")); @@ -1038,12 +1056,15 @@ SplitedResult ForwardBackwardSplit( VLOG(4) << "forward_value_map.size() is " << forward_value_map.size(); VLOG(4) << "backward_value_map.size() is " << backward_value_map.size(); - std::ostringstream print_stream; - print_stream << "ForwardProgram is :\n"; - forward_program->Print(print_stream); - print_stream << "BackwardProgram is:\n"; - backward_program->Print(print_stream); - VLOG(4) << "Splited Program (fwd | bwd): \n" << print_stream.str(); + if (FLAGS_print_ir) { + std::ostringstream print_stream; + print_stream << "ForwardProgram is :\n"; + forward_program->Print(print_stream); + print_stream << "BackwardProgram is:\n"; + backward_program->Print(print_stream); + std::cout << "Splited Program (fwd | bwd): \n" + << print_stream.str() << std::endl; + } // construct all attributes we needed. @@ -1138,7 +1159,7 @@ void BindUtils(pybind11::module *m) { y_s = paddle.matmul(x_s, x_s) z_s = paddle.add(y_s, y_s) k_s = paddle.tanh(z_s) - newir_program = ir.translate_to_new_ir(main_program.desc) + newir_program = pir.translate_to_new_ir(main_program.desc) print(newir_program) @@ -1158,6 +1179,53 @@ void BindUtils(pybind11::module *m) { Returns: list[str] : List of unregistered operators in paddle dialect, the name is expressed by origin op name. )DOC"); + m->def( + "translate_to_new_ir_with_param_map", + [](const framework::ProgramDesc &legacy_program) { + auto ir_ctx = pir::IrContext::Instance(); + auto program = std::make_shared(ir_ctx); + translator::ProgramTranslator program_translator(&legacy_program, + program.get()); + program_translator.Translate(); + return std::make_pair(program, program_translator.VarDesc2Value()); + }, + R"DOC( + Convert Fluid Program to New IR Program and get the mappings of VarDesc -> pir::Value. + + Args: + + legacy_program (ProgramDesc): The Fluid Program that will be converted. + + Returns: + Program: The New IR Program + dict[str, pir::Value]: Mapping between VarDesc(by name) and pir::Value. + + Raises: + PreconditionNotMet: If legacy_program has multi block will raise error. + + Examples: + .. code-block:: python + + import paddle + from paddle import pir + paddle.enable_static() + + x = paddle.randn([4, 4]) + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x_s = paddle.static.data('x', [4, 4], x.dtype) + x_s.stop_gradient = False + y_s = paddle.matmul(x_s, x_s) + z_s = paddle.add(y_s, y_s) + k_s = paddle.tanh(z_s) + newir_program, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc) + + print(newir_program) + print(mappings) + )DOC"); } void BindIrPass(pybind11::module *m) { diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index 1b0101a85537ad..1c4315e8ee1851 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -459,6 +459,7 @@ void BindPlace(pybind11::module &m) { // NOLINT py::enum_(m, "XPUVersion", py::arithmetic()) .value("XPU1", phi::backends::xpu::XPUVersion::XPU1) .value("XPU2", phi::backends::xpu::XPUVersion::XPU2) + .value("XPU3", phi::backends::xpu::XPUVersion::XPU3) .export_values(); m.def("get_xpu_device_count", platform::GetXPUDeviceCount); m.def("get_xpu_device_version", diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 19e813cc25e7a7..dcae0104f35598 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -195,6 +195,7 @@ limitations under the License. */ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/fluid/imperative/layout_autotune.h" +#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h" @@ -766,6 +767,42 @@ void BindVjp(pybind11::module *m) { out (bool): True means that the op has custom vjp rules, False means it does not. )DOC"); } + +void BindDecomp(pybind11::module *m) { + m->def("call_decomp", [](pir::Operation &fwd_op) { + py::list res; + paddle::dialect::DecompInterface decomp_interface = + fwd_op.dyn_cast(); + PADDLE_ENFORCE( + decomp_interface, + phi::errors::InvalidArgument( + "The decomp function is not registered in %s op ", fwd_op.name())); + std::vector> decomp_res = + decomp_interface.Decomp(&fwd_op); + for (size_t i = 0; i < decomp_res.size(); ++i) { + py::list sub_res; + for (size_t j = 0; j < decomp_res[i].size(); ++j) { + if (!decomp_res[i][j]) { + sub_res.append(nullptr); + } else { + sub_res.append(decomp_res[i][j]); + } + } + res.append(sub_res); + } + return res; + }); + + m->def("has_decomp", [](pir::Operation &fwd_op) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name()); + auto decomp_interface_impl = + fwd_op_info.GetInterfaceImpl(); + if (decomp_interface_impl == nullptr) return false; + return true; + }); +} + PYBIND11_MODULE(libpaddle, m) { BindImperative(&m); BindEager(&m); @@ -852,7 +889,7 @@ PYBIND11_MODULE(libpaddle, m) { m.def("clear_gradients", [](std::vector> param_list, bool set_to_zero) { - for (auto param : param_list) { + for (auto const ¶m : param_list) { param->ClearGradient(set_to_zero); } }); @@ -2940,6 +2977,7 @@ All parameter, weight, gradient are variables in Paddle. BindPIR(&m); BindVjp(&m); + BindDecomp(&m); } } // namespace pybind } // namespace paddle diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 3098145b801c72..71257dc588dac1 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -566,11 +566,7 @@ std::vector SetKernelDistOutput( if (tmp) { // TODO(GhostScreaming): now all dist case are nullptr if (tmp->impl() == nullptr) { - phi::DenseTensor dense_t; - // TODO(GhostScreaming): polish code, dist_attr is null now - phi::distributed::TensorDistAttr dist_attr; - auto dist_t = - std::make_shared(dense_t, dist_attr); + auto dist_t = std::make_shared(); tmp->set_impl(dist_t); } result.emplace_back( @@ -587,11 +583,7 @@ std::vector SetKernelDistOutput( out->reserve(out_size); std::vector results(out_size); for (size_t i = 0; i < out_size; ++i) { - phi::DenseTensor dense_t; - // TODO(GhostScreaming): polish code, dist_attr is null now - phi::distributed::TensorDistAttr dist_attr; - auto dist_t = - std::make_shared(dense_t, dist_attr); + auto dist_t = std::make_shared(); results[i] = dist_t.get(); out->emplace_back(); out->back().set_impl(dist_t); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 6fd1ddf87c4a25..8ba76b64f5f7af 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -211,7 +211,7 @@ inline phi::DenseTensor TransDataPlace(const phi::DenseTensor& tensor, // But the embarrassment is that this solution this solution makes training // slower. phi::DenseTensor out; - phi::DeviceContext* dev_ctx; + phi::DeviceContext* dev_ctx = nullptr; if (dst_place.GetType() != AllocationType::CPU) { dev_ctx = pool.Get(dst_place); } else { @@ -785,7 +785,8 @@ PrepareDataForDistTensor(const std::vector& input, // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified. VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor"; out.push_back(std::make_shared( - trans_in_tensor, dist_tensor->dist_attr())); + std::make_shared(trans_in_tensor), + dist_tensor->dist_attr())); } } else { out.push_back(nullptr); diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index a69afdfdfe2d8c..4984974b338ef1 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -11,9 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "paddle/phi/api/lib/kernel_dispatch.h" - +#include #ifdef _MSC_VER #include #endif @@ -68,6 +67,7 @@ BackendSet GetTensorBackendSet(const phi::TensorBase& t) { #endif phi::Backend backend_key = phi::TransToPhiBackend(t.place()); BackendSet backend_set(backend_key); + VLOG(10) << "update BackendSet by tensor: add [" << backend_key << "]"; if (backend_key == Backend::GPU && phi::DenseTensor::classof(&t) && static_cast(t).meta().use_gpudnn) { backend_set = backend_set | BackendSet(Backend::GPUDNN); diff --git a/paddle/phi/api/lib/kernel_dispatch.h b/paddle/phi/api/lib/kernel_dispatch.h index 847c2a7d14756e..7bd3524ed795c3 100644 --- a/paddle/phi/api/lib/kernel_dispatch.h +++ b/paddle/phi/api/lib/kernel_dispatch.h @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once +#include #include #include #include - #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/lib/backend_set.h" #include "paddle/phi/api/lib/data_type_set.h" @@ -99,11 +99,13 @@ struct KernelKeyParser : ArgsIterator { inline void AssignKernelKeySet(const phi::TensorBase& tensor) { // assign Backend BackendSet tensor_backend_set = detail::GetTensorBackendSet(tensor); + VLOG(8) << "Get BackendSet from tensor"; key_set.backend_set = key_set.backend_set | tensor_backend_set; // tensor's attribute use_gpudnn=False, explicitly disable gpudnn kernel if (tensor_backend_set == BackendSet(Backend::GPU) || disable_gpudnn) { disable_gpudnn = true; key_set.backend_set = key_set.backend_set - BackendSet(Backend::GPUDNN); + VLOG(8) << "Disable kernel backend: GPUDNN"; } // assign DataLayout phi::DataLayout tensor_layout = tensor.layout(); @@ -115,6 +117,7 @@ struct KernelKeyParser : ArgsIterator { auto promote_result = PromoteTypes(dtype_set); if (promote_result != DataType::UNDEFINED) { key_set.dtype = promote_result; + VLOG(8) << "promote kernel DataType:" << promote_result; } } diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 5d0e2f139c2137..da8b9125a71ddd 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -121,7 +121,7 @@ void CustomOpKernelContext::EmplaceBackAttr(paddle::any attr) { void CustomOpKernelContext::EmplaceBackAttrs( const std::vector& attrs) { - attrs_ = std::move(attrs); + attrs_ = attrs; } const Tensor& CustomOpKernelContext::InputAt(size_t idx) const { diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 1a57a578c78972..f50347fd6678aa 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -95,7 +95,7 @@ Tensor::Tensor(const Place &place, const std::vector &shape) { Tensor::Tensor(std::shared_ptr tensor_impl, const std::string &name) - : impl_(std::move(tensor_impl)), name_(std::move(name)) {} + : impl_(std::move(tensor_impl)), name_(name) {} /* Part 2: Dimension, DataType and DataLayout methods */ diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 7be497318443a7..5e39b764fa96d7 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1817,6 +1817,7 @@ infer_meta : func : UnchangedInferMeta param : [out] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : relu_grad backward: relu_double_grad @@ -2234,6 +2235,7 @@ infer_meta : func : UnchangedInferMeta param : [x] + spmd_rule : ElementwiseUnaryGradInferSpmd kernel : func : square_grad backward : square_double_grad diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 00189d880e67fb..3bd51c35e5d153 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -171,6 +171,12 @@ # Dist Branch will not generated in the API that doesn't have input tensor. SET_SINGLE_OUT_REPLICATED_DIST_ATTR = """ SetReplicatedDistAttrForOutput({}, spmd_info.first[0].process_mesh());""" +SET_VECTOR_OUT_REPLICATED_DIST_ATTR = """ + auto current_process_mesh = spmd_info.first[0].process_mesh(); + for (size_t i = 0; i < dist_out.size(); ++i) {{ + SetReplicatedDistAttrForOutput(dist_out[i], current_process_mesh); + }} +""" # 4. Select Kernel KERNEL_SELECTION_TEMPLATE = """ @@ -680,6 +686,10 @@ def generate_infer_global_shape_code(self) -> str: name=out_name ) output_args_code += f"{out_name}_meta_ptr_vec, " + if self.generate_general_infer_spmd is True: + set_out_dist_attr_code += ( + SET_VECTOR_OUT_REPLICATED_DIST_ATTR + ) else: output_decl_code += SINGLE_GLOBAL_META_OUT_DECL_TEMPLATE.format( out_name, out_name diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index d95bc19c57bff2..47eda81f5d0ca0 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -407,8 +407,8 @@ composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad) - backward_op : mish_grad - forward : mish (Tensor x, float threshold) -> Tensor(out) - args : (Tensor x, Tensor out_grad, float threshold) + forward : mish (Tensor x, float lambda) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float lambda) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -747,7 +747,7 @@ kernel : func : tile_grad no_need_buffer : x - composite : tile_grad(x, outgrad, repeat_times, x_grad) + composite : tile_grad(x, out_grad, repeat_times, x_grad) backward : tile_double_grad - backward_op : trans_layout_grad diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index a21cb39ac076f6..f7d3878e44847f 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -1316,6 +1316,18 @@ dropout1_out: Dropout1Out dropout2_out: Dropout2Out +- op : fused_gemm_epilogue + inputs: + {x : X, y : Y, bias : Bias} + outputs : + {out : Out, reserve_space: ReserveSpace} + +- op : fused_gemm_epilogue_grad + inputs: + {x : X, y : Y, reserve_space: ReserveSpace, out_grad : DOut} + outputs : + {x_grad : DX, y_grad : DY, bias_grad : DBias} + - op : fused_transpose extra : attrs : [str data_format = "AnyLayout"] @@ -2173,6 +2185,17 @@ outputs : out : Out +- op : pad + backward : pad_grad, pad_double_grad + inputs : + x : X + outputs : + out : Out + scalar: + pad_value: + data_type : float + support_tensor : true + - op : pad2d backward : pad2d_grad extra : @@ -3215,6 +3238,12 @@ outputs: {out: Out} +- op: dpsgd + inputs: + {param: Param,grad: Grad,learning_rate: LearningRate} + outputs: + param_out : ParamOut + - op: fetch (fetch_v2) inputs: {x: X} outputs: {out: Out} diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index a6ab3c3ec954f0..aaf6c4e1445ef4 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -134,7 +134,7 @@ backward : angle_grad - op : argmax - args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, int dtype = 3) + args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, DataType dtype = DataType::INT64) output : Tensor(out) infer_meta : func : ArgMinMaxInferMeta @@ -143,7 +143,7 @@ data_type : x - op : argmin - args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, int dtype = 3) + args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, DataType dtype = DataType::INT64) output : Tensor(out) infer_meta : func : ArgMinMaxInferMeta @@ -2073,6 +2073,7 @@ output : Tensor(out) infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : relu inplace : (x -> out) @@ -2458,6 +2459,7 @@ output : Tensor infer_meta : func : UnchangedInferMeta + spmd_rule : ElementwiseUnaryInferSpmd kernel : func : square {dense -> dense}, square_sr {selected_rows -> selected_rows} diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index 1c916682cf7b1c..55e629de34e7e2 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -20,8 +20,14 @@ endif() if(WITH_XPU) list(APPEND BACKENDS_SRCS xpu/xpu_context.cc xpu/xpu_info.cc) - list(APPEND BACKENDS_SRCS xpu/xpu_op_list.cc xpu/xpu1_op_list.cc - xpu/xpu2_op_list.cc xpu/xpu_l3_strategy.cc) + list( + APPEND + BACKENDS_SRCS + xpu/xpu_op_list.cc + xpu/xpu1_op_list.cc + xpu/xpu2_op_list.cc + xpu/xpu3_op_list.cc + xpu/xpu_l3_strategy.cc) list(APPEND BACKENDS_DEPS phi_dynload_xpti) endif() diff --git a/paddle/phi/backends/context_pool.cc b/paddle/phi/backends/context_pool.cc index 619db6f83fc240..7824fc3b160b10 100644 --- a/paddle/phi/backends/context_pool.cc +++ b/paddle/phi/backends/context_pool.cc @@ -61,7 +61,7 @@ thread_local const std::map>>* - ptr; + ptr = nullptr; if (external_device_contexts_ && external_device_contexts_->count(place)) { ptr = external_device_contexts_; } else { diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index 24ad5087769de5..748c80c0859c5e 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -494,7 +494,7 @@ std::vector DeviceManager::GetSelectedDeviceList( auto FLAGS_selected_devices = getenv(FLAGS.c_str()); if (FLAGS_selected_devices) { auto devices_str = paddle::string::Split(FLAGS_selected_devices, ','); - for (auto id : devices_str) { + for (auto const& id : devices_str) { device_list.push_back(atoi(id.c_str())); } } else { @@ -697,8 +697,8 @@ DeviceManager& DeviceManager::Instance() { } void DeviceManager::Release() { - stream::Stream::ReleaseAll(); event::Event::ReleaseAll(); + stream::Stream::ReleaseAll(); Instance().device_map_.clear(); Instance().device_impl_map_.clear(); } diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index 6989f32b18e9e0..bdb9e120d2884b 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -185,9 +185,9 @@ static inline std::string join(const std::string& part1, static inline std::vector split( const std::string& str, const std::string separator = " ") { std::vector str_list; - std::string::size_type firstPos; + std::string::size_type firstPos = 0; firstPos = str.find_first_not_of(separator, 0); - std::string::size_type lastPos; + std::string::size_type lastPos = 0; lastPos = str.find_first_of(separator, firstPos); while (std::string::npos != firstPos && std::string::npos != lastPos) { str_list.push_back(str.substr(firstPos, lastPos - firstPos)); @@ -263,7 +263,7 @@ static inline void* GetDsoHandleFromSearchPath( #endif // !_WIN32 std::vector dso_names = split(dso_name, ";"); void* dso_handle = nullptr; - for (auto dso : dso_names) { + for (auto const& dso : dso_names) { // 1. search in user config path by FLAGS dso_handle = GetDsoHandleFromSpecificPath(config_path, dso, dynload_flags); // 2. search in system default path @@ -272,7 +272,7 @@ static inline void* GetDsoHandleFromSearchPath( } // 3. search in extra paths if (nullptr == dso_handle) { - for (auto path : extra_paths) { + for (auto const& path : extra_paths) { VLOG(3) << "extra_paths: " << path; dso_handle = GetDsoHandleFromSpecificPath(path, dso, dynload_flags); } diff --git a/paddle/phi/backends/event.cc b/paddle/phi/backends/event.cc index 1c620afbad558d..c08b4b269b2d2e 100644 --- a/paddle/phi/backends/event.cc +++ b/paddle/phi/backends/event.cc @@ -46,6 +46,7 @@ Event::Event(const Place& place, event_t event) own_data_(false) {} Event::~Event() { + Synchronize(); Destroy(); std::unique_lock lock(g_events_mutex); g_events.remove(this); @@ -77,14 +78,35 @@ void Event::Destroy() { own_data_ = false; event_ = nullptr; device_ = nullptr; + is_recorded_ = false; } } -void Event::Record(const stream::Stream* stream) { stream->RecordEvent(this); } +void Event::Record(const stream::Stream* stream) { + if (device_) { + is_recorded_ = true; // synchronize the event during detroy + stream->RecordEvent(this); + } +} -bool Event::Query() const { return device_->QueryEvent(this); } +bool Event::Query() const { + if (device_ && is_recorded_) { + bool ret = device_->QueryEvent(this); + if (ret) { + is_recorded_ = + false; // event completed, do not need to synchronize the event. + } + return ret; + } else { + return true; + } +} -void Event::Synchronize() const { device_->SynchronizeEvent(this); } +void Event::Synchronize() const { + if (device_ && is_recorded_) { + device_->SynchronizeEvent(this); + } +} const Place& Event::GetPlace() const { return place_; } diff --git a/paddle/phi/backends/event.h b/paddle/phi/backends/event.h index 1dac619c2abf96..21dc9f47d7b89e 100644 --- a/paddle/phi/backends/event.h +++ b/paddle/phi/backends/event.h @@ -57,6 +57,7 @@ class Event { Device* device_; event_t event_; bool own_data_ = true; + mutable bool is_recorded_ = false; }; } // namespace event diff --git a/paddle/phi/backends/gpu/cuda/cudnn_helper.h b/paddle/phi/backends/gpu/cuda/cudnn_helper.h index 651a4247a12df0..74db3fc75bcd10 100644 --- a/paddle/phi/backends/gpu/cuda/cudnn_helper.h +++ b/paddle/phi/backends/gpu/cuda/cudnn_helper.h @@ -33,8 +33,12 @@ namespace phi { namespace backends { namespace gpu { +#define CUDNN_VERSION_COMPUTE(major, minor, patch) \ + ((major) <= 8 ? (major)*1000 + (minor)*100 + (patch) \ + : (major)*10000 + (minor)*100 + (patch)) + #define CUDNN_VERSION_MIN(major, minor, patch) \ - (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + (CUDNN_VERSION >= CUDNN_VERSION_COMPUTE(major, minor, patch)) enum class DataLayout { // Not use kNHWC, diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index d6ce3e750f65ff..7905320728bda5 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -686,7 +686,7 @@ struct GPUContext::Impl { void AddStreamCallback(const std::function& callback) const { // NOTE(zhiqiu): better use threadpool here, otherwise "std::async" may // launch too many threads and result in thread oversubscription. - auto* callback_func = new std::function(std::move(callback)); + auto* callback_func = new std::function(callback); auto* func = new std::function([this, callback_func] { std::lock_guard lock(stream_call_back_mtx_); VLOG(4) << "Stream callback"; diff --git a/paddle/phi/backends/gpu/gpu_info.cc b/paddle/phi/backends/gpu/gpu_info.cc index f6ca9d4168b2c8..1849faa4520774 100644 --- a/paddle/phi/backends/gpu/gpu_info.cc +++ b/paddle/phi/backends/gpu/gpu_info.cc @@ -47,7 +47,7 @@ std::vector GetSelectedDevices() { std::vector devices; if (!FLAGS_selected_gpus.empty()) { auto devices_str = Split(FLAGS_selected_gpus, ','); - for (auto id : devices_str) { + for (auto const& id : devices_str) { devices.push_back(atoi(id.c_str())); } } else { diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index a447df94cb4dc1..a29b5e110922a4 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -146,19 +146,40 @@ void InitGpuProperties(Place place, } #else size_t cudnn_dso_ver = dynload::cudnnGetVersion(); + auto get_cudnn_major = [](auto version) { + if (version < 9000) { + return version / 1000; + } + // CUDNN changes the CUDNN_VERSION rules after 9.0 + return version / 10000; + }; + auto get_cudnn_minor = [](auto version) { + if (version < 9000) { + return (version % 1000) / 100; + } + // CUDNN changes the CUDNN_VERSION rules after 9.0 + return (version % 10000) / 100; + }; + LOG_FIRST_N(WARNING, 1) << "device: " << static_cast(place.device) - << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." - << (cudnn_dso_ver % 1000) / 100 << "."; + << ", cuDNN Version: " + << get_cudnn_major(cudnn_dso_ver) << "." + << get_cudnn_minor(cudnn_dso_ver) << "."; // Check CUDA/CUDNN version compatiblity auto local_cuda_version = (*driver_version / 1000) * 10 + (*driver_version % 100) / 10; auto compile_cuda_version = (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10; + + // Compute cuDNN major + auto local_cudnn_major = get_cudnn_major(cudnn_dso_ver); + size_t compile_cudnn_major = CUDNN_MAJOR; + #if defined(__linux__) PADDLE_ENFORCE_EQ( (local_cuda_version / 10 < compile_cuda_version / 10) && - (cudnn_dso_ver / 1000 < CUDNN_VERSION / 1000), + (local_cudnn_major < compile_cudnn_major), false, phi::errors::InvalidArgument( "The installed Paddle is compiled with CUDA%d/cuDNN%d," @@ -167,9 +188,9 @@ void InitGpuProperties(Place place, "Please recompile or reinstall Paddle with compatible CUDA/cuDNN " "version.", compile_cuda_version / 10, - CUDNN_VERSION / 1000, + compile_cudnn_major, local_cuda_version / 10, - cudnn_dso_ver / 1000)); + local_cudnn_major)); #endif if (local_cuda_version < compile_cuda_version) { LOG_FIRST_N(WARNING, 1) @@ -269,15 +290,17 @@ void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) { PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle)); PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream)); #else - auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100; - auto compile_cudnn_version = CUDNN_VERSION / 100; - if (local_cudnn_version < static_cast(compile_cudnn_version)) { + auto version = phi::dynload::cudnnGetVersion(); + auto local_cudnn_major = + (version < 9000) ? version / 1000 : version / 10000; + auto local_cudnn_minor = + (version < 9000) ? (version % 1000) / 100 : (version % 10000) / 100; + if (version < static_cast(CUDNN_VERSION)) { LOG_FIRST_N(WARNING, 1) << "WARNING: device: " << place.device - << ". The installed Paddle is compiled with CUDNN " - << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 - << ", but CUDNN version in your machine is " - << local_cudnn_version / 10 << "." << local_cudnn_version % 10 + << ". The installed Paddle is compiled with CUDNN " << CUDNN_MAJOR + << "." << CUDNN_MINOR << ", but CUDNN version in your machine is " + << local_cudnn_major << "." << local_cudnn_minor << ", which may cause serious incompatible bug. " << "Please recompile or reinstall Paddle with compatible CUDNN " "version."; diff --git a/paddle/phi/backends/gpu/rocm/miopen_helper.h b/paddle/phi/backends/gpu/rocm/miopen_helper.h index b8ce6e22e939be..f7815e2ed851e0 100644 --- a/paddle/phi/backends/gpu/rocm/miopen_helper.h +++ b/paddle/phi/backends/gpu/rocm/miopen_helper.h @@ -61,8 +61,12 @@ inline const char* miopenGetErrorString(miopenStatus_t status) { } // no use, but will have compiling error if not defined +#define CUDNN_VERSION_COMPUTE(major, minor, patch) \ + ((major) <= 8 ? (major)*1000 + (minor)*100 + (patch) \ + : (major)*10000 + (minor)*100 + (patch)) + #define CUDNN_VERSION_MIN(major, minor, patch) \ - (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch))) + (CUDNN_VERSION >= CUDNN_VERSION_COMPUTE(major, minor, patch)) enum class DataLayout { // Not use kNHWC, diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc new file mode 100644 index 00000000000000..29a85493958949 --- /dev/null +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -0,0 +1,1028 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include +#include +#include +#include "paddle/phi/backends/xpu/xpu_op_list.h" + +namespace phi { +namespace backends { +namespace xpu { + +XPUOpMap& get_kl3_ops() { + // KL3支持的op,通过op_name, data_type, place来索引 + static XPUOpMap s_xpu3_kernels{ + {"add_act_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"add_layernorm_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"abs_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})}, + {"adadelta", XPUKernelSet({phi::DataType::FLOAT32})}, + {"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"adam_dense_param_sparse_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"addcmul_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"arg_max", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"argsort_grad", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32})}, + {"argsort", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"assign", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::BOOL})}, + {"assign_value", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL})}, + {"atan", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"atan_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"batch_norm_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"batch_norm", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bn_act_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bmm_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bce_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"beam_search", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"beam_search_decode", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"bilinear_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, + {"bitwise_and", XPUKernelSet({phi::DataType::BOOL})}, + {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_allgather", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"c_allreduce_max", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::INT32})}, + {"c_allreduce_sum", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::INT32})}, + {"c_broadcast", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"c_concat", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_identity", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"c_softmax_with_cross_entropy", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_softmax_with_cross_entropy_grad", + XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_reduce_sum", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"c_split", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::INT32})}, + {"c_sync_calc_stream", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"cast", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"check_finite_and_unscale", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"clip", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"clip_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"coalesce_tensor", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"concat_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"concat", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"conv2d_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv2d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv1d_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv2d_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv3d_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv3d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"conv2d_transpose", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"conv2d_transpose_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"cumsum", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"cumsum_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"cumprod", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"deformable_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"deformable_conv_v1_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"deformable_conv_v1", XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"depthwise_conv2d_transpose_grad", + XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d_transpose", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"diag_v2", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"diagonal", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"dropout_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"dropout", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"einsum", XPUKernelSet({phi::DataType::FLOAT32})}, + {"einsum_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"elementwise_add_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_add", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"elementwise_div_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_div", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"elementwise_floordiv", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_max_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_max", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_min_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_min", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_mul_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_mul", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"elementwise_pow", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_sub_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"elementwise_sub", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"elementwise_mod", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"embedding_with_eltwise_add_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"empty", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64})}, + {"embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"embedding_sparse_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"equal", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::BOOL})}, + {"exp_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"exp", XPUKernelSet({phi::DataType::FLOAT32})}, + {"expand_as_v2", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"expand_v2", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"fast_where_xpu", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"layer_norm_act_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fast_layernorm_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fc_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fill", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::FLOAT32})}, + {"fill_any", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::FLOAT32})}, + {"fill_any_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"fill_diagonal_tensor", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"fill_constant", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::BOOL, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"flatten2_grad", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT32})}, + {"flatten2", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT32})}, + {"flatten_contiguous_range_grad", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"flatten_contiguous_range", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"flatten_grad", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT32})}, + {"flatten", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT32})}, + {"flip", XPUKernelSet({phi::DataType::FLOAT32})}, + {"full_batch_size_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"fill_constant_batch_size_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"fused_multi_transformer_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"unfold", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"unfold_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"floor", XPUKernelSet({phi::DataType::FLOAT32})}, + {"gather_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"gather_nd_grad", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32})}, + {"gather_nd", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"gather", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL})}, + {"gaussian_random", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"gelu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"generate_proposals_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"generate_sequence_xpu", + XPUKernelSet({ + phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + })}, + {"grad_add", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"greater_equal", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"greater_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"hard_sigmoid", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"hard_swish_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"huber_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"huber_loss", XPUKernelSet({phi::DataType::FLOAT32})}, + {"kldiv_loss", XPUKernelSet({phi::DataType::FLOAT32})}, + {"kldiv_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"increment", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"index_put", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"index_sample_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"index_sample", + XPUKernelSet({phi::DataType::INT8, + phi::DataType::INT16, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::BOOL})}, + {"index_select_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"index_select", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"instance_norm", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"instance_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"inverse", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, + {"label_smooth", XPUKernelSet({phi::DataType::FLOAT32})}, + {"lars_momentum", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"layer_norm_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"layer_norm", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"leaky_relu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"less_equal", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"less_than", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"load", XPUKernelSet({phi::DataType::FLOAT32})}, + {"load_combine", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT8, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"log", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"log_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log_softmax", XPUKernelSet({phi::DataType::FLOAT32})}, + {"log_softmax_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"logical_and", XPUKernelSet({phi::DataType::BOOL})}, + {"logical_not", XPUKernelSet({phi::DataType::BOOL})}, + {"logical_or", XPUKernelSet({phi::DataType::BOOL})}, + {"logical_xor", XPUKernelSet({phi::DataType::BOOL})}, + {"lookup_table_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"lookup_table_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"masked_select", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"masked_select_grad", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"max_pool2d_with_index", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"max_pool2d_with_index_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"matmul_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"matmul_v2_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"matmul_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"matmul", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"mean_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"mean", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"merged_momentum", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"mish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"mish", XPUKernelSet({phi::DataType::FLOAT32})}, + {"momentum", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"mul", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"mul_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"multiply", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"multi_encoder_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"multiclass_nms3", XPUKernelSet({phi::DataType::FLOAT32})}, + {"nearest_interp_v2", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64})}, + {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"nll_loss", XPUKernelSet({phi::DataType::FLOAT32})}, + {"nll_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"not_equal", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, + {"one_hot_v2", + XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, + {"p_norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"p_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pad3d_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pad3d", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"pad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::FLOAT16})}, + {"pad_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::FLOAT16})}, + {"pixel_shuffle", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pixel_shuffle_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pool2d_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"pool2d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"pool3d_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"pool3d", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"pow", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})}, + {"prior_box", XPUKernelSet({phi::DataType::FLOAT32})}, + {"prelu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"prod_raw", XPUKernelSet({phi::DataType::FLOAT32})}, + {"range", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"randperm", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64})}, + {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reciprocal_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"reduce_any", XPUKernelSet({phi::DataType::BOOL})}, + {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_max", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_mean", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"reduce_sum", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT8, + phi::DataType::FLOAT32})}, + {"relu6", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"relu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"relu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"reshape2_grad", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT32})}, + {"reshape2", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT32})}, + {"reshape", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT32})}, + {"resnet_unit", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"resnet_unit_grad", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"rnn_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"roi_align", XPUKernelSet({phi::DataType::FLOAT32})}, + {"roi_align_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"roll", XPUKernelSet({phi::DataType::FLOAT32})}, + {"roll_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"scale", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"scatter", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32})}, + {"scatter_grad", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"scatter_nd_add", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"sampling_id", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, + {"set_value", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16, + phi::DataType::BOOL})}, + {"set_value_with_tensor", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL})}, + {"set_value_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16})}, + {"sgd_dense_param_sparse_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"silu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"silu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"size", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::INT16, + phi::DataType::BOOL, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"sigmoid_cross_entropy_with_logits_grad", + XPUKernelSet({phi::DataType::FLOAT32})}, + {"sigmoid_cross_entropy_with_logits", + XPUKernelSet({phi::DataType::FLOAT32})}, + {"shape", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"sigmoid", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sign", XPUKernelSet({phi::DataType::FLOAT32})}, + {"slice_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32})}, + {"slice", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"softmax", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"softmax_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_softmax_mask_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softmax_with_cross_entropy_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"softmax_with_cross_entropy", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"softplus", XPUKernelSet({phi::DataType::FLOAT32})}, + {"softplus_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sparse_coo_tensor", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::UINT8, + phi::DataType::INT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"split", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"split_with_num", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"square_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"square", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"squeeze2_grad", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32})}, + {"squeeze2", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"squeeze", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32})}, + {"squeeze_grad", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32})}, + {"stack", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"stack_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"strided_slice", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT16, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"strided_slice_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT16, + phi::DataType::INT32})}, + {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"take_along_axis", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"tanh_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"tanh", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"temporal_shift", XPUKernelSet({phi::DataType::FLOAT32})}, + {"temporal_shift_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"transfer_dtype", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::BOOL, + phi::DataType::UINT8, + phi::DataType::INT8, + phi::DataType::INT64, + phi::DataType::INT32})}, + {"tril_triu", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"tril", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"triu", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"tril_triu_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"tril_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"triu_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::FLOAT16})}, + {"tile", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::BOOL, + phi::DataType::FLOAT64, + phi::DataType::FLOAT32})}, + {"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"transpose2_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL})}, + {"transpose2", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL})}, + {"transpose_grad", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL})}, + {"transpose", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL})}, + {"truncated_gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})}, + {"top_k", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"top_k_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"update_loss_scaling", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"unbind", XPUKernelSet({phi::DataType::FLOAT32})}, + {"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})}, + {"unique", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"unsqueeze2_grad", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"unsqueeze2", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"unsqueeze_grad", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT32})}, + {"unsqueeze", + XPUKernelSet({phi::DataType::FLOAT64, + phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::INT8, + phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"unstack", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"unstack_grad", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})}, + {"where_index", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::BOOL, + phi::DataType::FLOAT32})}, + {"where_grad", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT16, + phi::DataType::FLOAT32})}, + {"where", + XPUKernelSet({phi::DataType::INT32, + phi::DataType::INT64, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"sin", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sin_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"cos", XPUKernelSet({phi::DataType::FLOAT32})}, + {"cos_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"linspace", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"randint", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})}, + {"group_norm", XPUKernelSet({phi::DataType::FLOAT32})}, + {"group_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"meshgrid", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64})}, + {"expand_v2_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"isnan_v2", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"yolo_box_xpu", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + + // AddMore + {"sequence_conv", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sequence_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"sequence_unpad", XPUKernelSet({phi::DataType::FLOAT32})}, + // Fused op + {"resnet_basic_block_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"resnet_basic_block", XPUKernelSet({phi::DataType::FLOAT32})}, + {"fused_gemm_epilogue", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_gemm_epilogue_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_attention", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_attention_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_feedforward", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_feedforward_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"lod_reset", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, + }; + + return s_xpu3_kernels; +} + +} // namespace xpu +} // namespace backends +} // namespace phi +#endif diff --git a/paddle/phi/backends/xpu/xpu_info.cc b/paddle/phi/backends/xpu/xpu_info.cc index ee51b19f482bff..96ff4cc2c81abb 100644 --- a/paddle/phi/backends/xpu/xpu_info.cc +++ b/paddle/phi/backends/xpu/xpu_info.cc @@ -195,9 +195,12 @@ XPUVersion get_xpu_version(int dev_id) { if (v == K100 || v == K200) { VLOG(1) << "KUNLUN device " << dev_id << " is XPU1\n"; return XPU1; - } else { + } else if (v < KL3_BEGIN) { VLOG(1) << "KUNLUN device " << dev_id << " is XPU2\n"; return XPU2; + } else { + VLOG(1) << "KUNLUN device " << dev_id << " is XPU3\n"; + return XPU3; } } @@ -211,6 +214,9 @@ int get_xpu_max_ptr_size(int dev_id) { case XPUVersion::XPU2: max_ptr_size = 6; break; + case XPUVersion::XPU3: + max_ptr_size = 12; + break; default: PADDLE_THROW(phi::errors::InvalidArgument( "Only support get max ptr size of XPU1 or XPU2.")); diff --git a/paddle/phi/backends/xpu/xpu_info.h b/paddle/phi/backends/xpu/xpu_info.h index b4fbdec7a93613..ad5a0b9745832d 100644 --- a/paddle/phi/backends/xpu/xpu_info.h +++ b/paddle/phi/backends/xpu/xpu_info.h @@ -92,7 +92,7 @@ class XPUDeviceGuard { int prev_id_{-1}; }; -enum XPUVersion { XPU1, XPU2 }; +enum XPUVersion { XPU1, XPU2, XPU3 }; XPUVersion get_xpu_version(int dev_id); int get_xpu_max_ptr_size(int dev_id); diff --git a/paddle/phi/backends/xpu/xpu_op_list.h b/paddle/phi/backends/xpu/xpu_op_list.h index 975a5d02b16b2b..1635ed2e6e8660 100644 --- a/paddle/phi/backends/xpu/xpu_op_list.h +++ b/paddle/phi/backends/xpu/xpu_op_list.h @@ -25,6 +25,7 @@ using XPUOpMap = std::unordered_map; XPUOpMap& get_kl1_ops(); XPUOpMap& get_kl2_ops(); +XPUOpMap& get_kl3_ops(); #ifdef PADDLE_WITH_XPU_KP bool is_xpu_kp_support_op(const std::string& fluid_op_name, diff --git a/paddle/phi/core/ddim.h b/paddle/phi/core/ddim.h index 57ad4d09ef463d..be11b4c9596cd9 100644 --- a/paddle/phi/core/ddim.h +++ b/paddle/phi/core/ddim.h @@ -227,7 +227,7 @@ std::vector vectorize(const DDim& ddim) { return result; } -int64_t product(const DDim& ddim); +TEST_API int64_t product(const DDim& ddim); bool contain_unknown_dim(const DDim& ddim); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 94d611e8043aa0..8e3e6405f4d29a 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -33,35 +33,45 @@ inline void check_defined(const DistTensor& dist_tensor, method_hint)); } -DistTensor::DistTensor(const phi::DenseTensor& global_value, +DistTensor::DistTensor() : value_(std::make_shared()) {} + +DistTensor::DistTensor(const std::shared_ptr& global_value, const TensorDistAttr& dist_attr) - : dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) { - // TODO(liyurui): This is a temporary solution. We need to support only infer - // meta when the input dense_tensor is empty. - // Support the value in DistTensor only has DenseTensor meta - // but without actual data. So we can visit its meta attr even if it is - // undefined. + : dims_(global_value->dims()), + dist_attr_(dist_attr), + value_(std::make_shared()) { + // If the current rank doesn't in process_mesh, we should create an + // uninitialized tensor only with tensor_meta. if (IsCurRankInMesh(dist_attr.process_mesh())) { - if (value_.initialized() && !dist_attr.is_replicated()) { + if (!dist_attr.is_replicated()) { // 1. create replicated global tensor - int64_t dims_size = global_value.dims().size(); - std::vector dims_mapping(dims_size, -1); - dist_attr_.set_dims_mapping(dims_mapping); - if (dist_attr_.is_partial()) { - dist_attr_.clean_partial_status(); - } - dist_attr_.set_dims_mapping(dims_mapping); + TensorDistAttr replicated_dist_attr(vectorize(global_value->dims())); + replicated_dist_attr.set_process_mesh(dist_attr.process_mesh()); + DistTensor replicated_tensor(global_value, replicated_dist_attr); // 2. reshard from replicated to other state - auto* func = ChooseProperReshardFunction(*this, dist_attr); - auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place()); - func->Eval(dev_ctx, *this, dist_attr, this); + auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr); + auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place()); + func->Eval(dev_ctx, replicated_tensor, dist_attr, this); + } else { + value_ = global_value; + } + } else { + // TODO(liyurui): The following logic is illegal, and should be removed + // later. It exist temporary because the basic execution procedure is not + // ready, even sometimes we try to construct a DistTensor with empty + // DistAttr. Here we warning when the DistAttr is empty for debug use. + if (dist_attr.empty()) { + LOG(WARNING) << "Try to construct a dist tensor with empty dist attr."; } + value_ = global_value; } } DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr) - : dims_(dims), dist_attr_(dist_attr) {} + : dims_(dims), + dist_attr_(dist_attr), + value_(std::make_shared()) {} void DistTensor::unsafe_set_dims(const DDim& dims) { if (this->initialized()) { @@ -80,39 +90,42 @@ void DistTensor::unsafe_set_dist_attr(const TensorDistAttr& dist_attr) { } int64_t DistTensor::numel() const { - check_defined(*this, "numel"); - return value_.numel(); + // DistTensor with uninitialized local tensor can + // also have numel. + return product(dims_); } const DDim& DistTensor::local_dims() const { check_defined(*this, "local_dims"); - return value_.dims(); + return value_->dims(); } bool DistTensor::valid() const { check_defined(*this, "valid"); - return value_.valid(); + return value_->valid(); } -bool DistTensor::defined() const { return value_.holder_ != nullptr; } +bool DistTensor::defined() const { return value_->holder_ != nullptr; } bool DistTensor::initialized() const { - return value_.holder_ != nullptr && value_.holder_->ptr(); + return value_->holder_ != nullptr && value_->holder_->ptr(); } DataType DistTensor::dtype() const { - check_defined(*this, "dtype"); - return value_.dtype(); + // DistTensor with uninitialized local tensor can + // also have dtype. + return value_->dtype(); } DataLayout DistTensor::layout() const { - check_defined(*this, "layout"); - return value_.layout(); + // DistTensor with uninitialized local tensor can + // also have layout. + return value_->layout(); } const Place& DistTensor::place() const { check_defined(*this, "place"); - return value_.holder_->place(); + return value_->holder_->place(); } void* DistTensor::AllocateFrom(Allocator* allocator, diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index c965733a7e0e8e..9e93ccf70c70b8 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -30,12 +30,12 @@ class DistTensor final /// \brief Careful to create dist tensor using default constructor. /// this should only used in reshard for now, and the dist properties /// will be set by reshard later. - DistTensor() = default; + DistTensor(); /// \brief Construct a dist tensor based dense tensor. /// \param global_value The global dense tensor of the current tensor. /// \param dist_attr The distributed attributes of the current tensor. - DistTensor(const phi::DenseTensor& global_value, + DistTensor(const std::shared_ptr& global_value, const TensorDistAttr& dist_attr); /// \brief Construct a empty dist tensor (for infer spmd) @@ -68,7 +68,7 @@ class DistTensor final /// \brief Returns the dense tensor value's const reference in dist tensor. /// \return The DenseTensor value's const reference - const DenseTensor& value() const { return value_; } + const DenseTensor& value() const { return *value_; } /// \brief Returns the mutable dense tensor value in dist tensor. /// \note If DenseTensor value is modified externally, the corresponding @@ -77,7 +77,7 @@ class DistTensor final /// so you need to make sure to consider it thoroughly when using /// this method. /// \return The mutable pointer of DenseTensor value - DenseTensor* unsafe_mutable_value() { return &value_; } + DenseTensor* unsafe_mutable_value() { return value_.get(); } /// \brief Returns the global dims of the dist tensor. /// \return The global dims of the dist tensor. @@ -126,7 +126,7 @@ class DistTensor final // The distributed attributes TensorDistAttr dist_attr_; // The local DenseTensor value - DenseTensor value_; + std::shared_ptr value_; }; } // namespace distributed diff --git a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc index fbeddae16dc408..9d5d8f43f76708 100644 --- a/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.cc @@ -38,10 +38,10 @@ ProcessMesh GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) { for (int64_t i = 0; i < shape_of_axis; ++i) { coord[axis] = i; int64_t rank = coord.back(); - for (int64_t j = coord.size() - 2; j >= 0; --j) { + for (int64_t j = static_cast(coord.size() - 2); j >= 0; --j) { rank += coord[j] * mesh.dim_size(j + 1); } - process_ids.emplace_back(rank); + process_ids.emplace_back(mesh.process_ids()[rank]); } ProcessMesh out_mesh(shape, process_ids, dim_names); @@ -58,7 +58,8 @@ int64_t FindFirstDiffShardAxis(const TensorDistAttr& in_dist_attr, const auto& out_dims_mapping = out_dist_attr.dims_mapping(); int64_t axis = -1; - for (int64_t i = in_dims_mapping.size() - 1; i >= 0; --i) { + for (int64_t i = static_cast(in_dims_mapping.size() - 1); i >= 0; + --i) { if (in_dims_mapping[i] != out_dims_mapping[i]) { axis = i; break; @@ -87,18 +88,24 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, const DistTensor& in, const TensorDistAttr& out_dist_attr, DistTensor* out) { + VLOG(3) << "Call SameNdMeshReshardFunction Eval"; const auto& in_dist_attr = in.dist_attr(); const auto& process_mesh = out_dist_attr.process_mesh(); int64_t first_diff_axis = FindFirstDiffShardAxis(in_dist_attr, out_dist_attr); + // Backup out_dist_attr to to avoid overwriting the out's dist attr + auto out_dist_attr_orig = out_dist_attr; + SetValue(out, in.value()); SetDistProps(out, in.dims(), in_dist_attr); // 1. change all the partial status to replicated status if needed if (in_dist_attr.is_partial()) { - const auto& in_partial_status = in_dist_attr.partial_status(); - const auto& out_partial_status = out_dist_attr.partial_status(); + // Copy in_dist_attr.partial_status to avoid overwriting the value of + // input when the output and input are the same value + const auto in_partial_status = in_dist_attr.partial_status(); + const auto& out_partial_status = out_dist_attr_orig.partial_status(); for (const auto& kv : in_partial_status) { if (out_partial_status.count(kv.first) != 0) { continue; @@ -173,9 +180,9 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, } // 3. Change replicated to partial - if (out_dist_attr.is_partial()) { + if (out_dist_attr_orig.is_partial()) { const auto& in_partial_status = out->dist_attr().partial_status(); - const auto& out_partial_status = out_dist_attr.partial_status(); + const auto& out_partial_status = out_dist_attr_orig.partial_status(); for (const auto& kv : out_partial_status) { if (in_partial_status.count(kv.first) != 0) { continue; @@ -211,7 +218,7 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, // 4. Change replicated to shard for (int64_t i = first_diff_axis; i >= 0; --i) { - int64_t out_mesh_axis = out_dist_attr.dims_mapping()[i]; + int64_t out_mesh_axis = out_dist_attr_orig.dims_mapping()[i]; if (out_mesh_axis != -1) { VLOG(3) << "Step4: out_mesh axis " << out_mesh_axis; // 4.1 Calculate the dist_attr after this transform diff --git a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc index b8e355e689caea..01824dd93bca19 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard_function.cc @@ -30,7 +30,7 @@ std::shared_ptr ReshardFunction::Eval( } void ReshardFunction::SetValue(DistTensor* tensor, const DenseTensor& value) { - tensor->value_ = value; + tensor->value_ = std::make_shared(value); } void ReshardFunction::SetDistProps(DistTensor* tensor, @@ -56,7 +56,7 @@ void ReshardFunction::SetDistProps(DistTensor* tensor, } DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) { - return &tensor->value_; + return tensor->value_.get(); } ReshardFunction* ChooseProperReshardFunction( diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc index 29aa1256e01937..3aafe1dc7fbeea 100644 --- a/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.cc @@ -53,7 +53,7 @@ void SToSReshardFunction::Eval(phi::DeviceContext* dev_ctx, const auto& in_process_ids = in_process_mesh.process_ids(); auto dtype = in.dtype(); const auto& logical_ddim = in.dims(); - int64_t nranks = in_process_ids.size(); + int64_t nranks = static_cast(in_process_ids.size()); int in_split_axis = GetSplitAxisWithDimsMapping(in.dist_attr().dims_mapping()).begin()->first; int out_split_axis = diff --git a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc index a6f49268c5612f..ea32163d67f624 100644 --- a/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.cc @@ -64,6 +64,7 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, const DistTensor& in, const TensorDistAttr& out_dist_attr, DistTensor* out) { + VLOG(3) << "Call SameStatusReshardFunction Eval"; const auto& in_dist_attr = in.dist_attr(); const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& in_process_ids = in_process_mesh.process_ids(); @@ -89,8 +90,8 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, for (const auto& iter : p2p_pair) { int64_t src = iter.first; int64_t dst = iter.second; - VLOG(3) << "Send/Recv from src " << src << " to dst " << dst; if (src == cur_global_rank) { + VLOG(3) << "Send from src " << src << " to dst " << dst; int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst); // Sice send kernel only has input, so we don't need to infermeta // actually. According to this reason, just use the kernel directly. @@ -102,6 +103,7 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, dst_local_rank, dynamic_shape); } else if (dst == cur_global_rank) { + VLOG(3) << "Recv from src " << src << " to dst " << dst; int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src); RESHARD_FUNCTOR_WITH_COMM(dev_ctx, PRecv, diff --git a/paddle/phi/core/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc index 9650d051f98fbe..6fbe2aa6761e2c 100644 --- a/paddle/phi/core/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -421,7 +421,7 @@ std::vector TCPStore::get(const std::string& key) { } void TCPStore::wait(const std::string& key) { - ReplyType reply; + ReplyType reply; // NOLINT VLOG(7) << "TCPStore wait."; _client->send_command_for_key(Command::WAIT, _key_prefix + key); reply = _client->receive_value(); diff --git a/paddle/phi/core/distributed/store/tcp_utils.cc b/paddle/phi/core/distributed/store/tcp_utils.cc index aaf00cb8000853..64c5424928b9ff 100644 --- a/paddle/phi/core/distributed/store/tcp_utils.cc +++ b/paddle/phi/core/distributed/store/tcp_utils.cc @@ -44,7 +44,7 @@ ::addrinfo* get_addr_info(const std::string host, const std::string port, int ai_flags, int family) { - ::addrinfo hints{}, *res; + ::addrinfo hints{}, *res = nullptr; hints.ai_flags = ai_flags; hints.ai_family = family; hints.ai_socktype = SOCK_STREAM; @@ -52,7 +52,7 @@ ::addrinfo* get_addr_info(const std::string host, const char* node = host.empty() ? nullptr : host.c_str(); const char* port_cstr = port.empty() ? nullptr : port.c_str(); - int n; + int n = 0; n = ::getaddrinfo(node, port_cstr, &hints, &res); const char* gai_err = ::gai_strerror(n); const char* proto = (family == AF_INET ? "IPv4" @@ -216,7 +216,7 @@ void send_string(SocketType socket, const std::string& s) { } std::string receive_string(SocketType socket) { - std::string::size_type size; + std::string::size_type size = 0; receive_bytes(socket, &size, 1); std::vector v(size); receive_bytes(socket, v.data(), size); diff --git a/paddle/phi/core/extended_tensor.h b/paddle/phi/core/extended_tensor.h index d02dbabde179fe..73cae43c0b54c0 100644 --- a/paddle/phi/core/extended_tensor.h +++ b/paddle/phi/core/extended_tensor.h @@ -18,12 +18,14 @@ limitations under the License. */ #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/utils/test_macros.h" + namespace phi { /// \brief The ExtendedTensor is a interface for custom designed class. /// If you want to pass some self-designed data as input/output to kernels, /// you can inherit from this class to store your self-designed data. -class ExtendedTensor : public TensorBase { +class TEST_API ExtendedTensor : public TensorBase { public: ExtendedTensor() = default; virtual ~ExtendedTensor() = default; diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index 19e707c40bc551..c7a0a81c7fb4f4 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1333,18 +1333,20 @@ PHI_DEFINE_EXPORTED_int32( "Multiple of the CUPTI device buffer size. If the timestamps have " "been dropped when you are profiling, try increasing this value."); +PHI_DEFINE_EXPORTED_bool(print_ir, false, "Whether print ir debug str."); + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) /** * Communication library related FLAG * Name: FLAGS_dynamic_static_unified_comm * Since Version: 2.5 - * Value Range: bool, default=false + * Value Range: bool, default=true * Example: * Note: Whether to use new communication library in auto parallel and static * mode. If true, it will use unified CommContextManager for communication. */ PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm, - false, + true, "Whether to use new communication library in auto " "parallel and static mode."); #endif // FLAGS_dynamic_static_unified_comm diff --git a/paddle/phi/core/generator.cc b/paddle/phi/core/generator.cc index 8cdbb290ea34f8..4541b81de4630a 100644 --- a/paddle/phi/core/generator.cc +++ b/paddle/phi/core/generator.cc @@ -242,7 +242,7 @@ uint64_t Generator::GetCurrentSeed() { uint64_t Generator::Seed() { std::lock_guard lock(this->mu_); - uint64_t seed; + uint64_t seed = 0; std::random_device de; seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF; this->state_.current_seed = seed; diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index adbcb8574518ba..18f3042bbf9c28 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -16,9 +16,7 @@ limitations under the License. */ namespace phi { -void InferMetaContext::SetMetaConfig(MetaConfig config) { - config_ = std::move(config); -} +void InferMetaContext::SetMetaConfig(MetaConfig config) { config_ = config; } void InferMetaContext::EmplaceBackInput(MetaTensor input) { int index = static_cast(inputs_.size()); @@ -96,7 +94,7 @@ InferMetaContext::OptionalInputsBetween(size_t start, size_t end) const { result.emplace_back(in.initialized() ? &in : nullptr); } - return paddle::optional>(std::move(result)); + return paddle::optional>(result); } return paddle::none; } diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index f9c1dca46b2fb5..69c7900def16ba 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -526,7 +526,7 @@ std::string KernelSelectionErrorMessage(const std::string& kernel_name, std::unordered_set dtype_set; // Record all kernel information of kernel_name - for (auto iter : KernelFactory::Instance().kernels()[kernel_name]) { + for (auto const& iter : KernelFactory::Instance().kernels()[kernel_name]) { KernelKey kernel_key = iter.first; if (kernel_key.backend() == target_key.backend()) { support_backend = true; diff --git a/paddle/phi/core/tensor_array.cc b/paddle/phi/core/tensor_array.cc index 8c717e151a1299..11a240596be73c 100644 --- a/paddle/phi/core/tensor_array.cc +++ b/paddle/phi/core/tensor_array.cc @@ -27,7 +27,7 @@ bool TensorArray::initialized() const { return false; } - for (auto tensor : tensors_) { + for (auto const& tensor : tensors_) { if (!tensor.initialized()) { return false; } diff --git a/paddle/phi/core/utils/type_info.h b/paddle/phi/core/utils/type_info.h index b4d908e2c1d9c0..9e31343ed04a42 100644 --- a/paddle/phi/core/utils/type_info.h +++ b/paddle/phi/core/utils/type_info.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include "paddle/utils/test_macros.h" namespace phi { @@ -40,7 +41,7 @@ class TypeInfo { }; template -class TypeInfoTraits { +class TEST_API TypeInfoTraits { public: static const TypeInfo kType; TypeInfoTraits(); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a3028027ebdd89..2aa8543eb82c32 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1170,7 +1170,7 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { void ElementwiseInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { - return ElementwiseRawInferMeta(x, y, -1, std::move(out)); + return ElementwiseRawInferMeta(x, y, -1, out); } void ElementwiseRawInferMeta(const MetaTensor& x, @@ -1435,7 +1435,7 @@ void FusedMatmulInferMeta(const MetaTensor& x, y_broadcasted = true; } - size_t M, N; + size_t M = 0, N = 0; if (transpose_x) { M = dims_x[ndims_x - 1]; } else { @@ -2136,7 +2136,7 @@ void MatmulInferMeta(const MetaTensor& x, y_broadcasted = true; } - size_t M, N; + size_t M = 0, N = 0; if (trans_x) { M = dims_x[ndims_x - 1]; } else { @@ -3028,7 +3028,7 @@ void YoloBoxInferMeta(const MetaTensor& x, "But received class_num (%s)", class_num)); - int box_num; + int box_num = 0; if ((dim_x[2] > 0 && dim_x[3] > 0) || config.is_runtime) { box_num = static_cast(dim_x[2] * dim_x[3] * anchor_num); } else { @@ -3103,7 +3103,7 @@ void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { y_broadcasted = true; } - size_t M, N; + size_t M = 0, N = 0; if (trans_x) { M = x_dims_vec[x_dims_n - 1]; } else { diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index d047670a9ee5fa..0aca25103f80a7 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -485,10 +485,10 @@ void FusedAttentionInferMeta(const MetaTensor& x, "(dim_embed, 3 * dim_embed).")); } else { // compute the mp nranks - nranks = (y_dim[0] * 3) / y_dim[1]; + nranks = static_cast((y_dim[0] * 3) / y_dim[1]); } - dim_head = y_dim[0] / (num_heads * nranks); - hidden_size = y_dim[0]; + dim_head = static_cast(y_dim[0] / (num_heads * nranks)); + hidden_size = static_cast(y_dim[0]); } else { PADDLE_ENFORCE_EQ(y_dim.size(), 4, @@ -512,9 +512,9 @@ void FusedAttentionInferMeta(const MetaTensor& x, "and must satisfy the limitations: " "(num_head * dim_head == dim_embed)")); } - num_heads = y_dim[1]; - dim_head = y_dim[2]; - hidden_size = y_dim[3]; + num_heads = static_cast(y_dim[1]); + dim_head = static_cast(y_dim[2]); + hidden_size = static_cast(y_dim[3]); } PADDLE_ENFORCE_EQ( @@ -1050,8 +1050,8 @@ void FusedGemmEpilogueInferMeta(const MetaTensor& x, auto x_mat_dims = phi::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1); - int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1]; - int K_from_y = trans_y ? y_dims[1] : y_dims[0]; + int K_from_x = static_cast(trans_x ? x_mat_dims[0] : x_mat_dims[1]); + int K_from_y = static_cast(trans_y ? y_dims[1] : y_dims[0]); PADDLE_ENFORCE_EQ( K_from_x, @@ -1086,7 +1086,7 @@ void FusedGemmEpilogueInferMeta(const MetaTensor& x, "The ReserveSpace would not be used when activation = \"none\"")); } else { int min_size_of_n = activation == "relu" ? 128 : 8; - int N_size = trans_y ? y_dims[0] : y_dims[1]; + int N_size = static_cast(trans_y ? y_dims[0] : y_dims[1]); PADDLE_ENFORCE_EQ(N_size % min_size_of_n, 0, phi::errors::InvalidArgument( diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 8f78755486a7d2..0cd5534a9c44ab 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1997,7 +1997,7 @@ static void Interpolate1DInferShapeCheck( return; } - int out_w_tmp; + int out_w_tmp = 0; if (scale_tensor) { auto scale_tensor_dim = scale_tensor.dims(); PADDLE_ENFORCE_EQ( @@ -2130,7 +2130,7 @@ static void Interpolate2DInferShapeCheck( return; } - int out_h_tmp, out_w_tmp; + int out_h_tmp = 0, out_w_tmp = 0; if (scale_tensor) { auto scale_tensor_dim = scale_tensor.dims(); @@ -2282,7 +2282,7 @@ static void Interpolate3DInferShapeCheck( return; } - int out_d_tmp, out_h_tmp, out_w_tmp; + int out_d_tmp = 0, out_h_tmp = 0, out_w_tmp = 0; if (scale_tensor) { auto scale_tensor_dim = scale_tensor.dims(); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 1c57e2fae92ac0..0e3ac3fb5ca2c8 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -74,7 +74,7 @@ void EyeInferMeta(const Scalar& num_rows, DataType dtype, MetaTensor* out, MetaConfig config) { - int64_t rows, columns; + int64_t rows = 0, columns = 0; if (!config.is_runtime && num_rows.FromTensor()) { rows = -1; } else { diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc index 24d6bed03c52d0..3a9e422320210f 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.cc +++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc @@ -309,6 +309,11 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr}}; } +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad) { + return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}}; +} + SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& out_grad, diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h index 736aeec35ed0a0..188e557e6737b0 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.h +++ b/paddle/phi/infermeta/spmd_rules/elementwise.h @@ -27,6 +27,9 @@ SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x); SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out); +SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad); + SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y); diff --git a/paddle/phi/infermeta/spmd_rules/flatten.cc b/paddle/phi/infermeta/spmd_rules/flatten.cc new file mode 100644 index 00000000000000..0a9c4111d8e7fa --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flatten.cc @@ -0,0 +1,203 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/flatten.h" +#include + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/dim_trans.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +int PreprocessAxis(int axis, int ndim) { + if (axis < 0) { + axis += ndim; + } + + PADDLE_ENFORCE_LT( + axis, + ndim, + phi::errors::InvalidArgument("The Start_axis or Stop_axis [%d] is not " + "less than the Tensor X's rank [%d].", + axis, + ndim)); + + return axis; +} + +std::vector MakeFlattenDimTrans( + const std::vector& src_shape, int start_axis, int stop_axis) { + std::vector ret; + + std::vector input_dims; + for (int64_t i = 0; i < static_cast(src_shape.size()); i++) { + if (i < start_axis || i > stop_axis) { + ret.emplace_back(new InputDim(i)); + } else { + input_dims.emplace_back(new InputDim(i)); + } + + if (i == stop_axis) { + ret.emplace_back(make_flatten(input_dims)); + } + } + + return ret; +} + +std::vector MakeFlattenDimTransReverse( + const std::vector& src_shape, int start_axis, int stop_axis) { + std::vector ret; + + std::vector tgt_splitted_shape; + for (int i = start_axis; i <= stop_axis; i++) { + tgt_splitted_shape.emplace_back(src_shape[i]); + } + + for (int64_t i = 0; i < static_cast(src_shape.size()); i++) { + if (i < start_axis) { + ret.emplace_back(new InputDim(i)); + } else if (i > stop_axis) { + ret.emplace_back(new InputDim(i - (stop_axis - start_axis))); + } else { + ret.emplace_back(make_split( + new InputDim(start_axis), tgt_splitted_shape, i - start_axis)); + } + } + + return ret; +} + +SpmdInfo FlattenInferSpmd(const DistMetaTensor& x, + int start_axis, + int stop_axis) { + // Step0: Verify input args based on flatten logic + auto src_shape = phi::vectorize(x.dims()); + int x_ndim = static_cast(src_shape.size()); + auto x_dist_attr_src = x.dist_attr(); + std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping.size())); + + // Step1: Build the transformation from + // the original shape to the target shape + + start_axis = PreprocessAxis(start_axis, x_ndim); + stop_axis = PreprocessAxis(stop_axis, x_ndim); + std::vector trans = + MakeFlattenDimTrans(src_shape, start_axis, stop_axis); + + // Step2: Infer the dims mapping of input (if reshard is + // needed) and output from the dimension transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(x, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping. + TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr out_dist_attr(x_dist_attr_src); + out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) << "]"; + VLOG(4) << "Start_axis: " << start_axis; + VLOG(4) << "Stop_axis: " << start_axis; + VLOG(4) << "Transformation from input to output:"; + for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping) + << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int start_axis, + int stop_axis) { + // Step0: Verify input args based on flatten logic + auto x_shape = phi::vectorize(x.dims()); + auto x_ndim = x_shape.size(); + auto out_shape = phi::vectorize(out.dims()); + int out_ndim = out_shape.size(); + auto out_dist_attr_src = out.dist_attr(); + std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + out_ndim, + out_dims_mapping.size(), + phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's " + "dims_mapping size [%d] are not matched.", + out_ndim, + out_dims_mapping.size())); + + // Step1: Build the transformation from the output shape + // to original shape. This function infers the dims mapping + // from output to input, we first get the transformation + // from output to input so that we can infer the dims mapping + // with the map from output axes to input axes. + + start_axis = PreprocessAxis(start_axis, x_ndim); + stop_axis = PreprocessAxis(stop_axis, x_ndim); + + std::vector trans = + MakeFlattenDimTransReverse(x_shape, start_axis, stop_axis); + + // Step2: Infer the dims mapping of input with + // output's dims_mapping and the transformation. + std::vector> dims_mapping_vec = + InferFromDimTrans(out, trans); + + // Step3: Update the dist attributes of input + // and output with the inferred dims mapping + TensorDistAttr out_dist_attr_dst(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]); + TensorDistAttr x_dist_attr(x.dist_attr()); + x_dist_attr.set_dims_mapping(dims_mapping_vec[1]); + + VLOG(4) << "FlattenInferSpmdReverse: Out shape: [" << str_join(out_shape) + << "] X shape: [" << str_join(x_shape) << "]"; + VLOG(4) << "Transformation from output to input:"; + for (int64_t i = 0, n = trans.size(); i < n; i++) { + DimTrans* t = trans[i]; + VLOG(4) << "\tX axis[" << i << "]: " << t->to_string(); + } + VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] " + << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]"; + VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n"; + + CleanUp(); + + return {{x_dist_attr}, {out_dist_attr_dst}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/flatten.h b/paddle/phi/infermeta/spmd_rules/flatten.h new file mode 100644 index 00000000000000..bb62d8c0d7b0a2 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/flatten.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo FlattenInferSpmd(const DistMetaTensor& x, + int start_axis, + int stop_axis); + +SpmdInfo FlattenInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int start_axis, + int stop_axis); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 98c2ebd7493b91..4893c7071f19e4 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -278,6 +278,14 @@ SpmdInfo MatmulInferSpmdReverse(const DistMetaTensor& x, return {{x_dist_attr_dst, y_dist_attr_dst}, {out_dist_attr_src}}; } +static bool DistAttrsAreBasicallyEqual( + const phi::distributed::TensorDistAttr& in_dist_attr, + const phi::distributed::TensorDistAttr& out_dist_attr) { + return (in_dist_attr.process_mesh() == out_dist_attr.process_mesh() && + in_dist_attr.dims_mapping() == out_dist_attr.dims_mapping() && + in_dist_attr.partial_status() == out_dist_attr.partial_status()); +} + SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y, const DistMetaTensor& out_grad, @@ -287,8 +295,8 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y, const char* debug_msg) { PADDLE_ENFORCE_EQ( - x_dist_attr, - y.dist_attr(), + DistAttrsAreBasicallyEqual(x_dist_attr, y.dist_attr()), + true, phi::errors::Unavailable("The matmul grad infer spmd `%s` verify " "error: left dist attr is %s, " "right dist attr is %s.", diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index dd89de7229b9a2..eda61be1f22846 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" +#include "paddle/phi/infermeta/spmd_rules/flatten.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/reduction.h" @@ -492,6 +493,11 @@ PD_REGISTER_SPMD_RULE(reshape2, PD_INFER_SPMD(phi::distributed::ReshapeInferSpmd), PD_INFER_SPMD(phi::distributed::ReshapeInferSpmdReverse)); +// flatten rule +PD_REGISTER_SPMD_RULE(flatten, + PD_INFER_SPMD(phi::distributed::FlattenInferSpmd), + PD_INFER_SPMD(phi::distributed::FlattenInferSpmdReverse)); + // embedding rule PD_REGISTER_SPMD_RULE( embedding, diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index b18ecf48363f04..d97a16e57fa614 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -255,6 +255,32 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, output_box->set_dtype(target_box.dtype()); } +void DpsgdInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + float clip, + float batch_size, + float sigma, + int size, + MetaTensor* param_out) { + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_EQ(phi::product(lr_dims), + 1, + phi::errors::InvalidArgument( + "Learning rate should have 1 dimension. But Received " + "LearningRate's dims [%s].", + phi::product(lr_dims))); + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad.dims(), + phi::errors::InvalidArgument( + "Param and Grad input of DpsgdOp should have same dimension. But " + "received Para's dim [%s] and Grad's dim [%s].", + param_dims, + grad.dims())); + param_out->set_dims(param_dims); +} void FlashAttnInferMeta(const MetaTensor& q, const MetaTensor& k, const MetaTensor& v, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index 47c4b9826da4a8..797835a1abd511 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -63,6 +63,15 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, MetaTensor* output_box, MetaConfig config = MetaConfig()); +void DpsgdInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + float clip, + float batch_size, + float sigma, + int size, + MetaTensor* param_out); + void FlashAttnInferMeta(const MetaTensor& q, const MetaTensor& k, const MetaTensor& v, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 6eaff66c583898..1dd9355549c021 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -148,18 +148,19 @@ void ArgMinMaxInferMeta(const MetaTensor& x, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, MetaTensor* out, MetaConfig config) { PADDLE_ENFORCE_EQ( - (dtype < 0 || dtype == 2 || dtype == 3), + (dtype == DataType::UNDEFINED || dtype == DataType::INT32 || + dtype == DataType::INT64), true, phi::errors::InvalidArgument( "The attribute of dtype in argmin/argmax must be [%s] or [%s], but " "received [%s]", DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT64), - DataTypeToString(phi::TransToPhiDataType(dtype)))); + DataTypeToString(dtype))); if (!config.is_runtime && axis.FromTensor()) { std::vector vec; @@ -177,10 +178,8 @@ void ArgMinMaxInferMeta(const MetaTensor& x, } } out->set_dims(phi::make_ddim(vec)); - if (dtype == 2) { - out->set_dtype(DataType::INT32); - } else if (dtype == 3) { - out->set_dtype(DataType::INT64); + if (dtype == DataType::INT32 || dtype == DataType::INT64) { + out->set_dtype(dtype); } return; } @@ -216,7 +215,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, if (int_axis < 0) int_axis += x_rank; if (config.is_runtime) { - if (dtype == phi::TransToProtoVarType(DataType::INT32)) { + if (dtype == DataType::INT32) { int64_t all_element_num = 0; if (flatten) { all_element_num = phi::product(x_dims); @@ -253,10 +252,8 @@ void ArgMinMaxInferMeta(const MetaTensor& x, } out->set_dims(phi::make_ddim(vec)); - if (dtype == 2) { - out->set_dtype(DataType::INT32); - } else if (dtype == 3) { - out->set_dtype(DataType::INT64); + if (dtype == DataType::INT32 || dtype == DataType::INT64) { + out->set_dtype(dtype); } } @@ -427,6 +424,14 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { out->set_dtype(x.dtype()); } +void CINNBroadcastInferMeta(const MetaTensor& x, + const std::vector& axes, + const std::vector& out_shape, + MetaTensor* out) { + out->set_dims(phi::make_ddim(out_shape)); + out->set_dtype(x.dtype()); +} + void ClassCenterSampleInferMeta(const MetaTensor& label, int num_classes, int num_samples, @@ -555,7 +560,7 @@ void CumWithIndicesInferMeta(const MetaTensor& x, phi::errors::InvalidArgument("dtype of indices must be int32 or int64")); if (indices_type == DataType::INT32) { - int _axis; + int _axis = 0; if (axis < 0) { _axis = axis + x_dims.size(); } else { @@ -1682,11 +1687,11 @@ void FrameInferMeta(const MetaTensor& x, "Attribute(axis) of FrameOp should 0 or -1, but got %s.", axis)); std::vector output_shape; - int seq_length; - int n_frames; + int seq_length = 0; + int n_frames = 0; - int start_axis; - int end_axis; + int start_axis = 0; + int end_axis = 0; if (axis == 0) { seq_length = static_cast(x_dims[0]); @@ -2566,12 +2571,12 @@ void OverlapAddInferMeta(const MetaTensor& x, "Attribute(axis) of OverlapAddOp should 0 or -1, but got %s.", axis)); std::vector output_shape; - int n_frames; - int frame_length; - int seq_length; + int n_frames = 0; + int frame_length = 0; + int seq_length = 0; - int start_axis; - int end_axis; + int start_axis = 0; + int end_axis = 0; if (axis == 0) { n_frames = static_cast(x_dims[0]); frame_length = static_cast(x_dims[1]); @@ -3143,8 +3148,8 @@ void QrInferMeta(const MetaTensor& x, x_dims.size(), 2, phi::errors::InvalidArgument("the rank of input must greater than 2")); - bool compute_q; - bool reduced_mode; + bool compute_q = false; + bool reduced_mode = false; int m = static_cast(x_dims[x_rank - 2]); int n = static_cast(x_dims[x_rank - 1]); int min_mn = std::min(m, n); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index a3b7e87d86d0bf..d79b53a71097e4 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -49,7 +49,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, MetaTensor* out, MetaConfig config = MetaConfig()); @@ -89,6 +89,11 @@ void CheckNumericsInferMeta(const MetaTensor& tensor, void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out); +void CINNBroadcastInferMeta(const MetaTensor& x, + const std::vector& axes, + const std::vector& out_shape, + MetaTensor* output); + void ClassCenterSampleInferMeta(const MetaTensor& label, int num_classes, int num_samples, diff --git a/paddle/phi/kernels/arg_min_max_kernel.h b/paddle/phi/kernels/arg_min_max_kernel.h index 258c8f21e0540b..5f1b4fc934fec2 100644 --- a/paddle/phi/kernels/arg_min_max_kernel.h +++ b/paddle/phi/kernels/arg_min_max_kernel.h @@ -25,7 +25,7 @@ void ArgMinKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out); template @@ -34,7 +34,7 @@ void ArgMaxKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index 6816a353ce5042..65bde5601128f8 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -307,7 +307,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad, + SoftplusGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad, ReluDoubleGradKernel) @@ -320,8 +321,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel) -PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad, - SoftplusDoubleGradKernel) +PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad, + SoftplusDoubleGradKernel) PD_REGISTER_KERNEL(tanh_triple_grad, CPU, diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 813a7ffc7ba422..a8169df1021d2b 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -201,7 +201,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel) -PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel) PD_REGISTER_KERNEL(exp, CPU, diff --git a/paddle/phi/kernels/cpu/allclose_kernel.cc b/paddle/phi/kernels/cpu/allclose_kernel.cc index c6a512aa95cb18..fd6cf3aebc2687 100644 --- a/paddle/phi/kernels/cpu/allclose_kernel.cc +++ b/paddle/phi/kernels/cpu/allclose_kernel.cc @@ -30,7 +30,7 @@ void AllCloseKernel(const Context& dev_ctx, const Scalar& atol, bool equal_nan, DenseTensor* out) { - double rtol_v, atol_v; + double rtol_v = NAN, atol_v = NAN; if (rtol.dtype() == DataType::FLOAT64) { rtol_v = rtol.to(); } else if (rtol.dtype() == DataType::FLOAT32) { @@ -58,7 +58,7 @@ void AllCloseKernel(const Context& dev_ctx, auto num = x.numel(); for (int64_t i = 0; i < num; ++i) { const T a = in_a[i], b = in_b[i]; - bool val; + bool val = false; if (std::isnan(a) || std::isnan(b)) { val = equal_nan && std::isnan(a) == std::isnan(b); } else { diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc index 20dfd2faff8a42..ce00926101f2cc 100644 --- a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc @@ -151,9 +151,9 @@ void ArgMinMaxKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { - if (dtype < 0) { + if (dtype == DataType::UNDEFINED) { phi::VisitDataTypeTiny( phi::DataType::INT64, VisitDataArgMinMaxFunctor( @@ -161,7 +161,7 @@ void ArgMinMaxKernel(const Context& dev_ctx, return; } phi::VisitDataTypeTiny( - phi::TransToPhiDataType(dtype), + dtype, VisitDataArgMinMaxFunctor( dev_ctx, x, axis.to(), keepdims, flatten, out)); } @@ -172,7 +172,7 @@ void ArgMinKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { ArgMinMaxKernel( dev_ctx, x, axis, keepdims, flatten, dtype, out); @@ -184,7 +184,7 @@ void ArgMaxKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { ArgMinMaxKernel( dev_ctx, x, axis, keepdims, flatten, dtype, out); diff --git a/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc index 7a95e47047a103..071140a2a54200 100644 --- a/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cumprod_grad_kernel.cc @@ -51,8 +51,8 @@ void CumprodGradKernel(const Context& dev_ctx, size_t numel = outer_dim * mid_dim * inner_dim; // deal with complex - const T* x_data_deal; - const T* out_data_deal; + const T* x_data_deal = nullptr; + const T* out_data_deal = nullptr; Allocator::AllocationPtr x_conj; Allocator::AllocationPtr out_conj; if (phi::IsComplexType(x.dtype())) { diff --git a/paddle/phi/kernels/cpu/diag_kernel.cc b/paddle/phi/kernels/cpu/diag_kernel.cc index 1576d80b15206b..fb15fcbe61f7e6 100644 --- a/paddle/phi/kernels/cpu/diag_kernel.cc +++ b/paddle/phi/kernels/cpu/diag_kernel.cc @@ -32,7 +32,7 @@ void DiagKernel(const Context& dev_ctx, T* out_data = dev_ctx.template Alloc(out); auto out_dims = out->dims(); - int64_t i; + int64_t i = 0; if (x_dims.size() <= 1) { phi::funcs::SetConstant set_padding_value; set_padding_value(dev_ctx, out, static_cast(padding_value)); diff --git a/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc b/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc index b8156459f2a923..aabca4c852e04b 100644 --- a/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc +++ b/paddle/phi/kernels/cpu/distribute_fpn_proposals_kernel.cc @@ -46,7 +46,7 @@ void DistributeFpnProposalsKernel( } std::vector fpn_rois_lod; - int fpn_rois_num; + int fpn_rois_num = 0; if (rois_num.get_ptr()) { fpn_rois_lod = funcs::GetLodFromRoisNum(dev_ctx, rois_num.get_ptr()); } else { diff --git a/paddle/phi/kernels/cpu/eigvals_kernel.cc b/paddle/phi/kernels/cpu/eigvals_kernel.cc index b0fc48db5739c2..cd4aaca2ecf83f 100644 --- a/paddle/phi/kernels/cpu/eigvals_kernel.cc +++ b/paddle/phi/kernels/cpu/eigvals_kernel.cc @@ -216,7 +216,7 @@ void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { // query workspace size T qwork; - int info; + int info = 0; funcs::lapackEig>('N', 'N', static_cast(n_dim), diff --git a/paddle/phi/kernels/cpu/group_norm_kernel.cc b/paddle/phi/kernels/cpu/group_norm_kernel.cc index a041c855346756..35975018dca1cc 100644 --- a/paddle/phi/kernels/cpu/group_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/group_norm_kernel.cc @@ -91,7 +91,7 @@ void GroupNormKernel(const Context& dev_ctx, if (data_layout == DataLayout::kNCHW) { for (int cid = 0; cid < number; cid++) { - int imid; + int imid = 0; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M) { // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used @@ -128,7 +128,7 @@ void GroupNormKernel(const Context& dev_ctx, } else { for (int cid = 0; cid < number; cid++) { iter_x_data = tmp_x + cid; - int imid; + int imid = 0; for (imid = 0; imid < imsize - (imsize % M); imid += M, iter_x_data += M * C) { // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used diff --git a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc index 7d5f60731f13dc..14937ea613936b 100644 --- a/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/instance_norm_grad_kernel.cc @@ -170,7 +170,7 @@ void InstanceNormDoubleGradKernel(const Context& dev_ctx, const auto* ddBias = ddbias.get_ptr(); phi::funcs::SetConstant set_constant; const auto& x_dims = x.dims(); - int N, C, H, W, D; + int N = 0, C = 0, H = 0, W = 0, D = 0; funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D); const int sample_size = static_cast(x.numel() / N / C); const int NxC = N * C; diff --git a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc index f1478d5e3b3e7e..e32738b4588c83 100644 --- a/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_grad_kernel.cc @@ -407,7 +407,7 @@ static void Interpolate1DCPUBwd( int align_mode, DenseTensor* input_grad) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_w = -1.0; @@ -508,7 +508,7 @@ static void Interpolate2DCPUBwd( int align_mode, DenseTensor* input_grad) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_h = -1; @@ -674,7 +674,7 @@ static void Interpolate3DCPUBwd( int align_mode, DenseTensor* input_grad) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(input.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_d = -1; diff --git a/paddle/phi/kernels/cpu/interpolate_kernel.cc b/paddle/phi/kernels/cpu/interpolate_kernel.cc index 198cba7d1e9488..7c957657ceb39e 100644 --- a/paddle/phi/kernels/cpu/interpolate_kernel.cc +++ b/paddle/phi/kernels/cpu/interpolate_kernel.cc @@ -561,7 +561,7 @@ static void Interpolate1DCPUFwd( int align_mode, DenseTensor* output) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_w = -1.; @@ -662,7 +662,7 @@ static void Interpolate2DCPUFwd( int align_mode, DenseTensor* output) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_h = -1; @@ -833,7 +833,7 @@ static void Interpolate3DCPUFwd( int align_mode, DenseTensor* output) { const DataLayout data_layout = phi::StringToDataLayout(data_layout_str); - int n, c, in_d, in_h, in_w; + int n = 0, c = 0, in_d = 0, in_h = 0, in_w = 0; funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); float scale_d = -1; diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index 3d21c49ee1e2bc..0713725127190a 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -39,7 +39,7 @@ void LapackSVD(const T* x_data, T* eigenvalues_data, int rows, int cols) { int lwork = 3 * mn + std::max(mx, 7 * mn); std::vector work(lwork); std::vector iwork(8 * mn); - int info; + int info = 0; phi::funcs::lapackSvd(jobz, rows, diff --git a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc index 336af33d8679b6..aa04288124a9b7 100644 --- a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc +++ b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc @@ -381,7 +381,7 @@ void MultiClassNMS(const Context& ctx, *num_nmsed_out = num_det; const T* scores_data = scores.data(); if (keep_top_k > -1 && num_det > keep_top_k) { - const T* sdata; + const T* sdata = nullptr; std::vector>> score_index_pairs; for (const auto& it : *indices) { int label = it.first; @@ -441,7 +441,7 @@ void MultiClassOutput(const Context& ctx, auto* scores_data = scores.data(); auto* bboxes_data = bboxes.data(); auto* odata = out->data(); - const T* sdata; + const T* sdata = nullptr; DenseTensor bbox; bbox.Resize({scores.dims()[0], box_size}); int count = 0; @@ -456,7 +456,7 @@ void MultiClassOutput(const Context& ctx, for (auto idx : indices) { odata[count * out_dim] = label; // label - const T* bdata; + const T* bdata = nullptr; if (scores_size == 3) { bdata = bboxes_data + idx * box_size; odata[count * out_dim + 1] = sdata[idx]; // score diff --git a/paddle/phi/kernels/cpu/norm_grad_kernel.cc b/paddle/phi/kernels/cpu/norm_grad_kernel.cc index 6d51a64c76bb1c..8bc46fa6cdffc6 100644 --- a/paddle/phi/kernels/cpu/norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/norm_grad_kernel.cc @@ -40,7 +40,7 @@ void NormGradKernel(const Context& ctx, auto xdim = in_x->dims(); if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; + int pre = 0, n = 0, post = 0; funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post); auto* place = ctx.eigen_device(); diff --git a/paddle/phi/kernels/cpu/norm_kernel.cc b/paddle/phi/kernels/cpu/norm_kernel.cc index 21af086515d71c..73540f83605920 100644 --- a/paddle/phi/kernels/cpu/norm_kernel.cc +++ b/paddle/phi/kernels/cpu/norm_kernel.cc @@ -33,10 +33,10 @@ void NormKernel(const Context& ctx, auto xdim = x.dims(); T eps = epsilon; if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; + int pre = 0, n = 0, post = 0; funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post); - DenseTensor* out_norm; + DenseTensor* out_norm = nullptr; DenseTensor out_norm_tmp; if (is_test) { auto out_dim = x.dims(); diff --git a/paddle/phi/kernels/cpu/p_norm_kernel.cc b/paddle/phi/kernels/cpu/p_norm_kernel.cc index 7a683438176bb9..3a837c96ec58a9 100644 --- a/paddle/phi/kernels/cpu/p_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/p_norm_kernel.cc @@ -58,7 +58,7 @@ void PNormKernel(const Context& dev_ctx, auto xdim = in_x->dims(); if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; + int pre = 0, n = 0, post = 0; GetDims(xdim, axis, &pre, &n, &post, asvector); for (int i = 0; i < xdim.size(); i++) { diff --git a/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc index 17539957a0d443..3a517cfa1fb612 100644 --- a/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/psroi_pool_grad_kernel.cc @@ -43,7 +43,7 @@ void PsroiPoolGradKernel(const Context& ctx, DenseTensor rois_batch_id_list; rois_batch_id_list.Resize({rois_num_t}); int* rois_batch_id_data = ctx.template Alloc(&rois_batch_id_list); - int rois_batch_size; + int rois_batch_size = 0; if (rois_num.get_ptr()) { rois_batch_size = static_cast(rois_num->numel()); auto* rois_num_t_data = rois_num->data(); diff --git a/paddle/phi/kernels/cpu/psroi_pool_kernel.cc b/paddle/phi/kernels/cpu/psroi_pool_kernel.cc index fe48ee9e7e88e3..3b15135133049f 100644 --- a/paddle/phi/kernels/cpu/psroi_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/psroi_pool_kernel.cc @@ -53,7 +53,7 @@ void PsroiPoolKernel(const Context& ctx, rois_batch_id_list.Resize({rois_num_t}); int* rois_batch_id_data = ctx.template Alloc(&rois_batch_id_list); - int rois_batch_size; + int rois_batch_size = 0; if (rois_num.get_ptr()) { rois_batch_size = static_cast(rois_num->numel()); auto* rois_num_data = rois_num->data(); diff --git a/paddle/phi/kernels/cpu/qr_kernel.cc b/paddle/phi/kernels/cpu/qr_kernel.cc index ac61e8e172ae6e..194906ae1dc346 100644 --- a/paddle/phi/kernels/cpu/qr_kernel.cc +++ b/paddle/phi/kernels/cpu/qr_kernel.cc @@ -29,8 +29,8 @@ void QrKernel(const Context& ctx, const std::string& mode, DenseTensor* q, DenseTensor* r) { - bool compute_q; - bool reduced_mode; + bool compute_q = false; + bool reduced_mode = false; std::tie(compute_q, reduced_mode) = phi::funcs::ParseQrMode(mode); auto numel = x.numel(); PADDLE_ENFORCE_GT( diff --git a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc index a8d6723cce6d10..ea098d09a5d562 100644 --- a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc @@ -43,5 +43,7 @@ PD_REGISTER_KERNEL(mean_raw, float, double, bool, + int, + int64_t, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc b/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc index 81868afc46318a..119f4ea1b0ac40 100644 --- a/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/roi_align_grad_kernel.cc @@ -29,7 +29,7 @@ void bilinear_interpolate_gradient(const int height, const T out_grad_this_bin, const T count, T* batch_grad_data) { - int x_low, y_low, x_high, y_high; + int x_low = 0, y_low = 0, x_high = 0, y_high = 0; T w1, w2, w3, w4; if (y < -1.0 || y > height || x < -1.0 || x > width) { w1 = w2 = w3 = w4 = 0; @@ -94,7 +94,7 @@ void RoiAlignGradKernel(const Context& dev_ctx, DenseTensor roi_batch_id_list = Empty(dev_ctx, {rois_num}); int* box_batch_id_data = roi_batch_id_list.data(); - int boxes_batch_size; + int boxes_batch_size = 0; if (boxes_num) { boxes_batch_size = static_cast(boxes_num->numel()); auto* boxes_num_data = boxes_num->data(); diff --git a/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc index 704a2b4b610fcc..e25a581cbd9dd9 100644 --- a/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/roi_pool_grad_kernel.cc @@ -37,7 +37,7 @@ void RoiPoolGradKernel(const Context& dev_ctx, DenseTensor box_batch_id_list = Empty(dev_ctx, {rois_num}); int* box_batch_id_data = box_batch_id_list.data(); - int boxes_batch_size; + int boxes_batch_size = 0; if (boxes_num) { boxes_batch_size = static_cast(boxes_num->numel()); auto* boxes_num_data = boxes_num->data(); diff --git a/paddle/phi/kernels/cpu/svd_kernel.cc b/paddle/phi/kernels/cpu/svd_kernel.cc index 1ae2d9cce0d400..a3f6f38fe47802 100644 --- a/paddle/phi/kernels/cpu/svd_kernel.cc +++ b/paddle/phi/kernels/cpu/svd_kernel.cc @@ -35,7 +35,7 @@ void LapackSvd( int lwork = full ? (4 * mn * mn + 6 * mn + mx) : (4 * mn * mn + 7 * mn); std::vector work(lwork); std::vector iwork(8 * mn); - int info; + int info = 0; phi::funcs::lapackSvd(jobz, rows, cols, diff --git a/paddle/phi/kernels/cpu/yolo_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/yolo_loss_grad_kernel.cc index 75fcf48cd4acf8..c876718d8a8b1c 100644 --- a/paddle/phi/kernels/cpu/yolo_loss_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/yolo_loss_grad_kernel.cc @@ -169,7 +169,7 @@ void YoloLossGradKernel(const Context& dev_ctx, T* input_grad_data = dev_ctx.template Alloc(input_grad); memset(input_grad_data, 0, input_grad->numel() * sizeof(T)); - const T* gt_score_data; + const T* gt_score_data = nullptr; DenseTensor gtscore; if (!(gt_score.is_initialized())) { gtscore.Resize({n, b}); diff --git a/paddle/phi/kernels/cpu/yolo_loss_kernel.cc b/paddle/phi/kernels/cpu/yolo_loss_kernel.cc index 275e83cc9b40fa..280ac791d049bb 100644 --- a/paddle/phi/kernels/cpu/yolo_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/yolo_loss_kernel.cc @@ -229,7 +229,7 @@ void YoloLossKernel(const Context& dev_ctx, gt_match_mask->Resize({n, b}); int* gt_match_mask_data = dev_ctx.template Alloc(gt_match_mask); - const T* gt_score_data; + const T* gt_score_data = nullptr; DenseTensor gtscore; if (!(gt_score.is_initialized())) { gtscore.Resize({n, b}); diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index c50194bfaf009a..b2c2d493c48ad3 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -799,6 +799,31 @@ struct SoftplusGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct SoftplusGradFunctor> + : public BaseActivationFunctor> { + float beta; + float threshold; + typename BaseActivationFunctor>::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + auto x_beta = static_cast>(beta) * x; // NOLINT + dx.device(d) = + (x_beta > static_cast>(threshold)) + .select(dout, + dout / (static_cast>(1) + (-x_beta).exp()) + .unaryExpr(Conj())); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct SoftplusDoubleGradFunctor : public BaseActivationFunctor { float beta; @@ -3681,7 +3706,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor { MPType x = static_cast(arg_x); MPType b = static_cast(beta); MPType t = static_cast(threshold); - MPType x_beta = x * beta; + MPType x_beta = x * static_cast(beta); return static_cast(x_beta > t ? x : log(one + exp(x_beta)) / b); } }; @@ -3711,6 +3736,34 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaSoftplusGradFunctor> + : public BaseActivationFunctor> { + using MPType = typename phi::dtype::MPTypeTrait>::Type; + MPType one = static_cast(1.0f); + float beta; + float threshold; + + typename BaseActivationFunctor>::AttrPair GetAttrs() { + return {{"beta", &beta}, {"threshold", &threshold}}; + } + + // dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x)) + __device__ __forceinline__ ComplexType operator()( + const ComplexType arg_dout, const ComplexType arg_x) const { + MPType dout = static_cast(arg_dout); + MPType x = static_cast(arg_x); + MPType b = static_cast(beta); + MPType t = static_cast(threshold); + MPType x_beta = x * static_cast(beta); + return x_beta > t + ? dout + : static_cast>(dout / conj(one + exp(-x_beta))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaAtanhGradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index d98b2f17476948..a1f9c1eb4346cb 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -511,313 +511,6 @@ void LaunchBroadcastKernel( #endif } -#ifndef PADDLE_WITH_XPU_KP -HOSTDEVICE static int64_t ConvertSrcIdxToDstIdx( - int64_t src_idx, - const phi::Array &src_strides, - const phi::Array &dst_strides, - int rank) { - int64_t dst_idx = 0; - int64_t origin_src_idx = src_idx; - for (int k = 0; k < rank; ++k) { - auto local_idx = src_idx / src_strides[k + 1]; - src_idx -= local_idx * src_strides[k + 1]; - - if (dst_strides[k] != dst_strides[k + 1]) { - dst_idx += local_idx * dst_strides[k + 1]; - } - } - return dst_idx; -} - -template -struct ReadVecDataWithInt64Index { - template - static __device__ __forceinline__ void Apply( - const Array1 &in, - ArgsT *args, - int64_t idx, - const Array2 &need_broadcast, - const phi::Array &src_strides, - const Array3 &dst_strides, - int rank, - bool is_boundary) { - using Type = std::tuple_element_t; - if (is_boundary) { -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - std::get(args[i]) = in[Index][ConvertSrcIdxToDstIdx( - idx + i, src_strides, dst_strides[Index], rank)]; - } - } else { - if (!need_broadcast[Index]) { - kps::ReadData( - args, reinterpret_cast(in[Index]) + idx, 1); - } else { -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - std::get(args[i]) = in[Index][ConvertSrcIdxToDstIdx( - idx + i, src_strides, dst_strides[Index], rank)]; - } - } - } - } -}; - -template -__global__ void BroadcastKernelWithInt64Index( - const phi::Array::kValue> - &ins, - OutT *out, - phi::Array, - MaxWithOne::kValue> ins_strides, - phi::Array out_strides, - phi::Array::kValue> need_broadcasts, - int rank, - Functor functor) { - int64_t numel = out_strides[0]; - int64_t idx = - (static_cast(blockIdx.x) * blockDim.x + threadIdx.x) * VecSize; - int64_t stride = static_cast(blockDim.x) * gridDim.x * VecSize; - int64_t limit = numel - VecSize; - - using Traits = phi::funcs::FunctionTraits; - using ArgsT = typename Traits::ArgsTuple; - - ArgsT args[VecSize]; - phi::AlignedVector out_vec; - for (; idx <= limit; idx += stride) { - Unroller::step( - ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, false); - -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - out_vec[i] = static_cast(Apply(functor, args[i])); - } - phi::Store(out_vec, out + idx); - } - - if (idx < numel) { - int remain = numel - idx; // remain is always less than VecSize, therefore - // `int` is enough here - Unroller::step( - ins, args, idx, need_broadcasts, out_strides, ins_strides, rank, true); - for (int i = 0; i < remain; ++i) { - out_vec[idx + i] = static_cast(Apply(functor, args[i])); - } - } -} - -template -struct LaunchBroadcastKernelWithInt64IndexHelper { - static void Run(const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor functor) { - PADDLE_THROW(phi::errors::PermissionDenied( - "Unreachable code branch. This may be a bug.")); - } -}; - -template -struct LaunchBroadcastKernelWithInt64IndexHelper { - static void Run(const KPDevice &ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor functor) { - using Traits = phi::funcs::FunctionTraits; - using ArgsT = typename Traits::ArgsTuple; - ArgsT arg; - phi::Array::kValue> - ins_ptrs; - UnrollerWithoutVecSize::step(ins, arg, &ins_ptrs); - - auto *out_tensor = (*outs)[0]; - auto *out_ptr = ctx.Alloc(out_tensor); - - phi::Array, - MaxWithOne::kValue> - ins_expand_dims; - phi::Array broadcast_out_dims; - int rank; - if (Arity == 1) { - rank = ins[0]->dims().size(); - for (int i = 0; i < rank; ++i) { - broadcast_out_dims[i] = ins[0]->dims()[i]; - } - ins_expand_dims[0] = broadcast_out_dims; - } else if (Arity >= 2) { - CalculateBroadcastDims(ins[0]->dims().Get(), - ins[1]->dims().Get(), - ins[0]->dims().size(), - ins[1]->dims().size(), - axis, - ins_expand_dims[0].GetMutable(), - ins_expand_dims[1].GetMutable(), - broadcast_out_dims.GetMutable(), - &rank); - for (int i = 2; i < Arity; ++i) { - auto tmp_dims = broadcast_out_dims; - phi::Array tmp_expand_dims; - int tmp_rank; - PADDLE_ENFORCE_GE(rank, - ins[i]->dims().size(), - phi::errors::InvalidArgument( - "Unsupported reverse broadcast when the input " - "tensor number is larger than 2.")); - CalculateBroadcastDims(tmp_dims.Get(), - ins[i]->dims().Get(), - rank, - ins[i]->dims().size(), - axis, - tmp_expand_dims.GetMutable(), - ins_expand_dims[i].GetMutable(), - broadcast_out_dims.GetMutable(), - &tmp_rank); - PADDLE_ENFORCE_EQ(rank, - tmp_rank, - phi::errors::InvalidArgument( - "Wrong broadcast algorithm. This may be a bug.")); - } - } - - phi::Array, - MaxWithOne::kValue> - ins_strides; - phi::Array::kValue> need_broadcasts; - phi::Array out_strides; - const auto &out_dims = out_tensor->dims(); - if (rank <= out_dims.size()) { - out_strides = ShapeToStride(out_dims.Get(), rank); - } else { - out_strides = ShapeToStride(broadcast_out_dims.Get(), rank); - } - - for (int i = 0; i < Arity; ++i) { - ins_strides[i] = ShapeToStride(ins_expand_dims[i].Get(), rank); - need_broadcasts[i] = - !IsSameShape(out_strides.Get(), ins_strides[i].Get(), rank + 1); - } - - int64_t numel = out_strides[0]; - auto gpu_config = - phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); - - BroadcastKernelWithInt64Index - <<>>(ins_ptrs, - out_ptr, - ins_strides, - out_strides, - need_broadcasts, - rank, - functor); - } - - private: - static void CalculateBroadcastDims(const int64_t *x_dims, - const int64_t *y_dims, - int nx, - int ny, - int axis, - int64_t *x_out_dims, - int64_t *y_out_dims, - int64_t *broadcast_out_dims, - int *length) { - PADDLE_ENFORCE_GE( - axis, 0, phi::errors::InvalidArgument("Invalid axis value: %d", axis)); - if (nx == ny) { - *length = nx; - for (int i = 0; i < nx; ++i) { - if (x_dims[i] != y_dims[i]) { - PADDLE_ENFORCE_EQ( - x_dims[i] == 1 || y_dims[i] == 1, - true, - phi::errors::InvalidArgument("Cannot broadcast input shape where " - "x_dims[%d] = %d, y_dims[%d] = %d.", - i, - x_dims[i], - i, - y_dims[i])); - } - broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i]); - x_out_dims[i] = x_dims[i]; - y_out_dims[i] = y_dims[i]; - } - } else if (nx > ny) { - *length = nx; - for (int i = nx - axis; i < ny; ++i) { - PADDLE_ENFORCE_EQ( - y_dims[i], - 1, - phi::errors::InvalidArgument( - "The trailing Y.shape[%d] should be 1 but got %d.", - i, - y_dims[i])); - } - - for (int i = 0; i < nx; ++i) { - if (i >= axis && i - axis < ny) { - if (x_dims[i] != y_dims[i - axis]) { - PADDLE_ENFORCE_EQ(x_dims[i] == 1 || y_dims[i - axis] == 1, - true, - phi::errors::InvalidArgument( - "Cannot broadcast input shape where " - "x_dims[%d] = %d, y_dims[%d] = %d.", - i, - x_dims[i], - i - axis, - y_dims[i - axis])); - } - broadcast_out_dims[i] = std::max(x_dims[i], y_dims[i - axis]); - x_out_dims[i] = x_dims[i]; - y_out_dims[i] = y_dims[i - axis]; - } else { - broadcast_out_dims[i] = x_dims[i]; - x_out_dims[i] = x_dims[i]; - y_out_dims[i] = 1; - } - } - } else { - CalculateBroadcastDims(y_dims, - x_dims, - ny, - nx, - axis, - y_out_dims, - x_out_dims, - broadcast_out_dims, - length); - } - } - - static bool IsSameShape(const int64_t *x, const int64_t *y, int rank) { - for (int i = 0; i < rank; ++i) { - if (x[i] != y[i]) return false; - } - return true; - } - - static phi::Array ShapeToStride( - const int64_t *arr, int rank) { - phi::Array strides; - strides[rank] = 1; - for (int i = rank - 1; i >= 0; --i) { - strides[i] = strides[i + 1] * arr[i]; - } - return strides; - } -}; -#endif - template typename std::enable_if::value, void>::type BroadcastKernelForDifferentVecSize(const KPDevice &ctx, @@ -825,25 +518,6 @@ BroadcastKernelForDifferentVecSize(const KPDevice &ctx, std::vector *outs, int axis, Functor func) { -#ifndef PADDLE_WITH_XPU_KP - constexpr bool kEnabledInt64IndexKernel = (NumOuts == 1 && Arity <= 3); - bool use_int64_index_kernel = - kEnabledInt64IndexKernel && - (*outs)[0]->numel() >= std::numeric_limits::max(); - if (use_int64_index_kernel) { - LaunchBroadcastKernelWithInt64IndexHelper::Run(ctx, - ins, - outs, - axis, - func); - return; - } -#endif - auto classifier = BroadcastTypeClassifier(ins, outs, axis); LaunchBroadcastKernel( diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc index 2b667c32d9db3f..597b8f231760bf 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -92,7 +92,7 @@ struct cpu_gather_scatter_functor { outer_dim_size *= index_dims[i]; } int64_t index_idx = 0; - int64_t self_idx, src_idx; + int64_t self_idx = 0, src_idx = 0; // N layer loop squeezed into 3 layers loop for (int64_t i = 0; i < inner_dim_size; i++) { diff --git a/paddle/phi/kernels/funcs/gpc.cc b/paddle/phi/kernels/funcs/gpc.cc index b3199d88f5888e..47a3001b4fda2d 100644 --- a/paddle/phi/kernels/funcs/gpc.cc +++ b/paddle/phi/kernels/funcs/gpc.cc @@ -87,7 +87,7 @@ const std::array, 3> next_h_state = { */ static void reset_it(it_node **it) { - it_node *itn; + it_node *itn = nullptr; while (*it) { itn = (*it)->next; @@ -97,7 +97,7 @@ static void reset_it(it_node **it) { } static void reset_lmt(lmt_node **lmt) { - lmt_node *lmtn; + lmt_node *lmtn = nullptr; while (*lmt) { lmtn = (*lmt)->next; @@ -140,7 +140,7 @@ static void insert_bound(edge_node **b, edge_node *e) { } static edge_node **bound_list(lmt_node **lmt, double y) { - lmt_node *existing_node; + lmt_node *existing_node = nullptr; if (!*lmt) { /* Add node onto the tail end of the LMT */ @@ -407,7 +407,7 @@ static void add_edge_to_aet(edge_node **aet, edge_node *edge, edge_node *prev) { static void add_intersection( it_node **it, edge_node *edge0, edge_node *edge1, double x, double y) { - it_node *existing_node; + it_node *existing_node = nullptr; if (!*it) { /* Append a new node to the tail of the list */ @@ -440,7 +440,7 @@ static void add_st_edge(st_node **st, it_node **it, edge_node *edge, double dy) { - st_node *existing_node; + st_node *existing_node = nullptr; double den = 0.0; double r = 0.0; double x = 0.0; @@ -486,8 +486,8 @@ static void add_st_edge(st_node **st, } static void build_intersection_table(it_node **it, edge_node *aet, double dy) { - st_node *st; - st_node *stp; + st_node *st = nullptr; + st_node *stp = nullptr; edge_node *edge = nullptr; /* Build intersection table for the current scanbeam */ @@ -706,7 +706,7 @@ static void new_tristrip(polygon_node **tn, } static bbox *create_contour_bboxes(gpc_polygon *p) { - bbox *box; + bbox *box = nullptr; int c = 0; int v = 0; @@ -744,8 +744,8 @@ static bbox *create_contour_bboxes(gpc_polygon *p) { } static void minimax_test(gpc_polygon *subj, gpc_polygon *clip, gpc_op op) { - bbox *s_bbox; - bbox *c_bbox; + bbox *s_bbox = nullptr; + bbox *c_bbox = nullptr; int s = 0; int c = 0; int *o_table = nullptr; diff --git a/paddle/phi/kernels/funcs/im2col.cc b/paddle/phi/kernels/funcs/im2col.cc index e4c470e1a7064f..44dd15ead335be 100644 --- a/paddle/phi/kernels/funcs/im2col.cc +++ b/paddle/phi/kernels/funcs/im2col.cc @@ -137,7 +137,7 @@ class Col2ImFunctor { int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; if ((im_row_idx) >= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { - int im_offset; + int im_offset = 0; if (data_layout != DataLayout::kNHWC) { im_offset = (c_im * im_height + im_row_idx) * im_width + im_col_idx; diff --git a/paddle/phi/kernels/funcs/jit/benchmark.cc b/paddle/phi/kernels/funcs/jit/benchmark.cc index 83a9a4a45d643f..894a711ddec6d7 100644 --- a/paddle/phi/kernels/funcs/jit/benchmark.cc +++ b/paddle/phi/kernels/funcs/jit/benchmark.cc @@ -113,7 +113,7 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) { BenchFunc benchmark; std::vector> infos; auto funcs = jit::GetAllCandidateFuncsWithTypes(attr); - for (auto f : funcs) { + for (auto const& f : funcs) { infos.push_back(std::make_pair(f.first, benchmark(f.second, args...))); } @@ -128,7 +128,7 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) { std::ostringstream loginfos; loginfos << "Kernel Type " << jit::to_string(KernelTuple::kernel_type) << ": " << attr << ": "; - for (auto pair : infos) { + for (auto const& pair : infos) { loginfos << pair.first << " takes " << pair.second << " us; "; } LOG(INFO) << loginfos.str(); diff --git a/paddle/phi/kernels/funcs/jit/gen_base.cc b/paddle/phi/kernels/funcs/jit/gen_base.cc index a80f9817c476ac..3758aaf4cace8d 100644 --- a/paddle/phi/kernels/funcs/jit/gen_base.cc +++ b/paddle/phi/kernels/funcs/jit/gen_base.cc @@ -47,7 +47,7 @@ void GenBase::dumpCode(const unsigned char* code) const { } void* GenBase::operator new(size_t size) { - void* ptr; + void* ptr = nullptr; constexpr size_t alignment = 32ul; #ifdef _WIN32 ptr = _aligned_malloc(size, alignment); @@ -71,8 +71,8 @@ void GenBase::operator delete(void* ptr) { } std::vector packed_groups(int n, int k, int* block_out, int* rest_out) { - int block; - int max_num_regs; + int block = 0; + int max_num_regs = 0; if (phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f)) { block = ZMM_FLOAT_BLOCK; max_num_regs = 32; diff --git a/paddle/phi/kernels/funcs/jit/helper.cc b/paddle/phi/kernels/funcs/jit/helper.cc index 5c93637649f897..c135d6ee3177dd 100644 --- a/paddle/phi/kernels/funcs/jit/helper.cc +++ b/paddle/phi/kernels/funcs/jit/helper.cc @@ -104,7 +104,7 @@ KernelType to_kerneltype(const std::string& act) { template <> void pack_weights(const float* src, float* dst, int n, int k) { - int block, rest; + int block = 0, rest = 0; const auto groups = packed_groups(n, k, &block, &rest); std::for_each(groups.begin(), groups.end(), [&](int i) { PADDLE_ENFORCE_GT(i, diff --git a/paddle/phi/kernels/funcs/jit/more/intrinsic/layer_norm.cc b/paddle/phi/kernels/funcs/jit/more/intrinsic/layer_norm.cc index d7d62d6815501a..4b50de277a9c28 100644 --- a/paddle/phi/kernels/funcs/jit/more/intrinsic/layer_norm.cc +++ b/paddle/phi/kernels/funcs/jit/more/intrinsic/layer_norm.cc @@ -44,8 +44,8 @@ void LayerNorm(float* x, __m256 mean_vec, var_vec; __m128 hi, lo; __m256 tmp = _mm256_setzero_ps(); - size_t offset; - size_t j; + size_t offset = 0; + size_t j = 0; __m256 reverse_num_vec = _mm256_div_ps( _mm256_set1_ps(1.0), _mm256_set1_ps(static_cast(right))); __m256 epsilon_vec = _mm256_set1_ps(epsilon); diff --git a/paddle/phi/kernels/funcs/maxouting.cc b/paddle/phi/kernels/funcs/maxouting.cc index 40b184865a5202..9c32453511f75d 100644 --- a/paddle/phi/kernels/funcs/maxouting.cc +++ b/paddle/phi/kernels/funcs/maxouting.cc @@ -43,7 +43,7 @@ void MaxOutFunctor::operator()(const DeviceContext& context, int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { T ele = static_cast(-FLT_MAX); - int input_idx, output_idx; + int input_idx = 0, output_idx = 0; for (int ph = 0; ph < groups; ++ph) { if (axis == 1) { input_idx = (new_bindex + new_cindex) * groups + ph * fea_size + f; @@ -89,7 +89,7 @@ void MaxOutGradFunctor::operator()( for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; for (int f = 0; f < fea_size; ++f) { - int input_idx0, output_idx; + int input_idx0 = 0, output_idx = 0; bool continue_match = true; if (axis == 1) { input_idx0 = (blen + clen) * groups + f; diff --git a/paddle/phi/kernels/funcs/pooling.cc b/paddle/phi/kernels/funcs/pooling.cc index ae68da49653fff..0573430c2010c5 100644 --- a/paddle/phi/kernels/funcs/pooling.cc +++ b/paddle/phi/kernels/funcs/pooling.cc @@ -1592,8 +1592,8 @@ class MaxPool2dWithIndexFunctor { T1* output_data = context.template Alloc(output); T2* mask_data = context.template Alloc(mask); - int hstart, hend; - int wstart, wend; + int hstart = 0, hend = 0; + int wstart = 0, wend = 0; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < output_height; ++ph) { @@ -1730,9 +1730,9 @@ class MaxPool3dWithIndexFunctor { T1* output_data = context.template Alloc(output); T2* mask_data = context.template Alloc(mask); - int dstart, dend; - int hstart, hend; - int wstart, wend; + int dstart = 0, dend = 0; + int hstart = 0, hend = 0; + int wstart = 0, wend = 0; for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { for (int pd = 0; pd < output_depth; ++pd) { diff --git a/paddle/phi/kernels/funcs/tensor_formatter.cc b/paddle/phi/kernels/funcs/tensor_formatter.cc index 0b9d4f31d553e3..16d3b38bced7c1 100644 --- a/paddle/phi/kernels/funcs/tensor_formatter.cc +++ b/paddle/phi/kernels/funcs/tensor_formatter.cc @@ -66,7 +66,7 @@ std::string TensorFormatter::Format(const phi::DenseTensor& print_tensor, if (print_tensor_lod_) { log_stream << " - lod: {"; const phi::LoD& lod = print_tensor.lod(); - for (auto level : lod) { + for (auto const& level : lod) { log_stream << "{"; bool is_first = true; for (auto i : level) { diff --git a/paddle/phi/kernels/funcs/vol2col.cc b/paddle/phi/kernels/funcs/vol2col.cc index e505fcb3de3372..b5d6086feda770 100644 --- a/paddle/phi/kernels/funcs/vol2col.cc +++ b/paddle/phi/kernels/funcs/vol2col.cc @@ -123,7 +123,7 @@ class Vol2ColFunctor { int64_t col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; - int64_t vol_idx; + int64_t vol_idx = 0; if (data_layout != DataLayout::kNHWC) { vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * input_width + @@ -248,7 +248,7 @@ class Col2VolFunctor { if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { - int vol_idx; + int vol_idx = 0; if (data_layout != DataLayout::kNHWC) { vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) * input_width + diff --git a/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h b/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h index c98a5f69ae0d6a..44c02338f8b543 100644 --- a/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h +++ b/paddle/phi/kernels/fused_bn_add_activation_grad_kernel.h @@ -21,13 +21,13 @@ namespace phi { template void FusedBatchNormAddActGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &y_grad, const DenseTensor &scale, const DenseTensor &bias, + const DenseTensor &y, const DenseTensor &saved_mean, const DenseTensor &saved_variance, const DenseTensor &reserve_space, + const DenseTensor &y_grad, float momentum, float epsilon, const std::string &act_type, diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py index cbe4571c5d010c..7caf30236bb79e 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py @@ -39,8 +39,8 @@ def find_arch_range(min_arch, max_arch): - assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH - assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH + assert min_arch >= DEFAULT_ARCH[0] and min_arch <= MAX_ARCH + assert max_arch >= DEFAULT_ARCH[0] and max_arch <= MAX_ARCH assert min_arch <= max_arch n = len(DEFAULT_ARCH) diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py index cd21c12a4323a0..8dd51f0c797a43 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_variable_forward_kernels.py @@ -39,8 +39,8 @@ def find_arch_range(min_arch, max_arch): - assert min_arch >= DEFAULT_ARCH[0] and min_arch < MAX_ARCH - assert max_arch >= DEFAULT_ARCH[0] and max_arch < MAX_ARCH + assert min_arch >= DEFAULT_ARCH[0] and min_arch <= MAX_ARCH + assert max_arch >= DEFAULT_ARCH[0] and max_arch <= MAX_ARCH assert min_arch <= max_arch n = len(DEFAULT_ARCH) diff --git a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu index e19b468b54a355..3b9618db02db05 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bn_add_activation_grad_kernel.cu @@ -45,13 +45,13 @@ using BatchNormParamType = typename CudnnDataType::BatchNormParamType; template void FusedBatchNormAddActGradKernel(const Context &dev_ctx, const DenseTensor &x, - const DenseTensor &y, - const DenseTensor &y_grad, const DenseTensor &scale, const DenseTensor &bias, + const DenseTensor &y, const DenseTensor &saved_mean, const DenseTensor &saved_variance, const DenseTensor &reserve_space, + const DenseTensor &y_grad, float momentum, float epsilon, const std::string &act_type, diff --git a/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu index 0344a71b970622..71e778ca6574e4 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_embedding_eltwise_layernorm_kernel.cu @@ -16,6 +16,8 @@ #include #include "paddle/phi/common/float16.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/common/place.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" #include "paddle/phi/core/kernel_registry.h" @@ -48,12 +50,6 @@ void EmbeddingEltWiseLayerNormKernel( DenseTensor in_ids_(phi::DataType::INT64), in_embs_(phi::DataType::INT64); DDim in_dim{input_num}; - int device_id; -#ifdef PADDLE_WITH_HIP - hipGetDevice(&device_id); -#else - cudaGetDevice(&device_id); -#endif in_ids_.Resize(in_dim); in_embs_.Resize(in_dim); @@ -68,29 +64,19 @@ void EmbeddingEltWiseLayerNormKernel( in1s.push_back(reinterpret_cast(ids[i]->data())); in2s.push_back(reinterpret_cast(embs[i]->data())); } -#ifdef PADDLE_WITH_HIP - hipMemcpyAsync(in_ids_d, - in1s.data(), - sizeof(int64_t) * input_num, - hipMemcpyHostToDevice, - dev_ctx.stream()); - hipMemcpyAsync(in_embs_d, - in2s.data(), - sizeof(int64_t) * input_num, - hipMemcpyHostToDevice, - dev_ctx.stream()); -#else - cudaMemcpyAsync(in_ids_d, - in1s.data(), - sizeof(int64_t) * input_num, - cudaMemcpyHostToDevice, - dev_ctx.stream()); - cudaMemcpyAsync(in_embs_d, - in2s.data(), - sizeof(int64_t) * input_num, - cudaMemcpyHostToDevice, - dev_ctx.stream()); -#endif + + phi::memory_utils::Copy(phi::GPUPlace{}, + in_ids_d, + phi::CPUPlace{}, + in1s.data(), + sizeof(int64_t) * input_num, + dev_ctx.stream()); + phi::memory_utils::Copy(phi::GPUPlace{}, + in_embs_d, + phi::CPUPlace{}, + in2s.data(), + sizeof(int64_t) * input_num, + dev_ctx.stream()); // should be (B * S * hidden) auto id0_dims = ids[0]->dims(); diff --git a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h index ed311e520681f0..12e64caa54b0a6 100644 --- a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h +++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h @@ -1325,16 +1325,18 @@ inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT k.y = rotary_embedding_transform(k.y, cos.y, sin.x); } -inline __device__ void apply_rotary_embedding(uint2& q, // NOLINT - uint2& k, // NOLINT - float4& cos, // NOLINT - float4& sin) { // NOLINT +inline __device__ void apply_rotary_embedding( + uint2& q, // NOLINT equals 4 half. + uint2& k, // NOLINT + float4& cos, // NOLINT 2 float2 cos. + float4& sin) { // NOLINT Float4_& cos_ = *reinterpret_cast(&cos); Float4_& sin_ = *reinterpret_cast(&sin); + // cos_.x is float2 q.x = rotary_embedding_transform(q.x, cos_.x, sin_.x); k.x = rotary_embedding_transform(k.x, cos_.x, sin_.x); q.y = rotary_embedding_transform(q.y, cos_.y, sin_.y); - k.y = rotary_embedding_transform(k.y, cos_.y, sin_.x); + k.y = rotary_embedding_transform(k.y, cos_.y, sin_.y); } inline __device__ void apply_rotary_embedding(uint4& q, // NOLINT diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index b65eaa5d7757d1..c67864bc13f573 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -381,9 +381,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad, - SoftplusDoubleGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad, + SoftplusGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad, + SoftplusDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel) diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index acfe4dd5a2941b..6eeba717ece0dd 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -250,7 +250,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel) -PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) +PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel) PD_REGISTER_KERNEL(exp, GPU, diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu index 41f5f4c3f4d051..caa635255b9878 100644 --- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu +++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu @@ -209,9 +209,9 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { - if (dtype < 0) { + if (dtype == DataType::UNDEFINED) { phi::VisitDataTypeTiny( phi::DataType::INT64, VisitDataCudaArgMinMaxFunctor( @@ -219,7 +219,7 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx, return; } phi::VisitDataTypeTiny( - phi::TransToPhiDataType(dtype), + dtype, VisitDataCudaArgMinMaxFunctor( dev_ctx, x, axis.to(), keepdims, flatten, out)); } @@ -230,7 +230,7 @@ void ArgMinKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { ArgMinMaxOpCUDAKernel( dev_ctx, x, axis, keepdims, flatten, dtype, out); @@ -242,7 +242,7 @@ void ArgMaxKernel(const Context& dev_ctx, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { ArgMinMaxOpCUDAKernel( dev_ctx, x, axis, keepdims, flatten, dtype, out); diff --git a/paddle/phi/kernels/gpu/contiguous_kernel.cu b/paddle/phi/kernels/gpu/contiguous_kernel.cu index b8dee10e31cdeb..357e104afb01c8 100644 --- a/paddle/phi/kernels/gpu/contiguous_kernel.cu +++ b/paddle/phi/kernels/gpu/contiguous_kernel.cu @@ -20,26 +20,120 @@ limitations under the License. */ #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { +template +__global__ void ContiguousCaseZeroFunc( + const T* input_data, + T* out_data, + phi::Array input_stride) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; +} template -__global__ void ContiguousFunc( +__global__ void ContiguousCaseOneFunc( const T* input_data, T* out_data, phi::Array input_stride, - phi::Array dims, - const int64_t numel) { - int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; -#pragma unroll - for (int64_t i = gid; i < numel; i += blockDim.x * gridDim.x) { + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { int64_t input_offset = 0; - int64_t index_tmp = i; + int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + #pragma unroll for (int dim = N - 1; dim >= 0; --dim) { - input_offset += index_tmp % dims[dim] * input_stride[dim]; - index_tmp = index_tmp / dims[dim]; + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; } - out_data[i] = input_data[input_offset]; + out_data[output_offset] = input_data[input_offset]; } } @@ -135,49 +229,214 @@ void ContiguousKernel(const Context& dev_ctx, input_stride[0] = 1; } - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; - - switch (rank) { - case 1: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 2: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 3: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 4: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 5: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 6: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 7: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 8: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - case 9: - ContiguousFunc<<>>( - input_data, output_data, input_stride, input_dims, numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", rank)); + dim3 grid(1, 1, 1), block(1, 1, 1); + + int tmp = 1; + + for (int i = 0; i < 3 && i < rank; i++) { + tmp *= input_dims[rank - 1 - i]; + } + + if (rank <= 6 && tmp <= 1024 && + (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { + if (rank >= 1) { + block.x = input_dims[rank - 1]; + } + + if (rank >= 2) { + block.y = input_dims[rank - 2]; + } + + if (rank >= 3) { + block.z = input_dims[rank - 3]; + } + + switch (rank) { + case 1: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 2: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 3: + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + ContiguousCaseZeroFunc<<>>( + input_data, output_data, input_stride); + break; + } + } else { + phi::Array cur_input_dims; + block.x = 512; + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + ContiguousCaseOneFunc + <<>>(input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = + input_dims[rank - 4] * input_dims[rank - 5] * input_dims[rank - 6]; + grid.z = + input_dims[rank - 7] * input_dims[rank - 8] * input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + ContiguousCaseOneFunc<<>>( + input_data, + output_data, + input_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", rank)); + } } } diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu index 8829d32596be1e..bd1d7db96cfeca 100644 --- a/paddle/phi/kernels/gpu/full_kernel.cu +++ b/paddle/phi/kernels/gpu/full_kernel.cu @@ -44,6 +44,7 @@ void FullKernel(const Context& dev_ctx, out->Resize(phi::make_ddim(shape.GetData())); int numel = out->numel(); dev_ctx.template Alloc(out); + if (numel > 0) { // in transformer model the numel of outpout will be zero. std::vector inputs = {}; diff --git a/paddle/phi/kernels/gpu/multinomial_kernel.cu b/paddle/phi/kernels/gpu/multinomial_kernel.cu index ba137b6fadc761..96fc3d1ac2b2e5 100644 --- a/paddle/phi/kernels/gpu/multinomial_kernel.cu +++ b/paddle/phi/kernels/gpu/multinomial_kernel.cu @@ -191,7 +191,7 @@ void MultinomialKernel(const Context& dev_ctx, if (int_num_samples == 1) { ArgMaxKernel( - dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out); + dev_ctx, rand, -1, true, false, DataType::INT64, out); } else { std::vector out_dim_vec = vectorize(out->dims()); DenseTensor value = Empty(dev_ctx, IntArray(out_dim_vec)); diff --git a/paddle/phi/kernels/gpu/strided_copy_kernel.cu b/paddle/phi/kernels/gpu/strided_copy_kernel.cu index 65dae3fc89efe9..e72eca2f936e19 100644 --- a/paddle/phi/kernels/gpu/strided_copy_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_copy_kernel.cu @@ -48,6 +48,127 @@ __global__ void StridedCopyFunc( } } +template +__global__ void StridedCopyCaseZeroFunc( + const T* input_data, + phi::Array input_stride, + T* output_data, + phi::Array output_stride) { + int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + int64_t output_offset = input_offset; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; + output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void StridedCopyCaseOneFunc( + const T* input_data, + phi::Array input_stride, + T* out_data, + phi::Array output_stride, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + int64_t output_offset = input_offset; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + output_offset += coordinate[N - 1 - dim] * output_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template __global__ void Strided2ContiguousFunc( const T* input_data, @@ -71,6 +192,123 @@ __global__ void Strided2ContiguousFunc( } } +template +__global__ void Strided2ContiguousCaseZeroFunc( + const T* input_data, + phi::Array input_stride, + T* output_data) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void Strided2ContiguousCaseOneFunc( + const T* input_data, + phi::Array input_stride, + T* out_data, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template __global__ void Contiguous2StridedFunc( const T* input_data, @@ -94,6 +332,123 @@ __global__ void Contiguous2StridedFunc( } } +template +__global__ void Contiguous2StridedCaseZeroFunc( + const T* input_data, + T* output_data, + phi::Array output_stride) { + int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + int64_t output_offset = 0; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void Contiguous2StridedCaseOneFunc( + const T* input_data, + T* out_data, + phi::Array output_stride, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + int64_t output_offset = 0; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + output_offset += coordinate[N - 1 - dim] * output_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template void StridedCopyKernel(const Context& dev_ctx, const DenseTensor& input, @@ -145,8 +500,6 @@ void StridedCopyKernel(const Context& dev_ctx, } auto numel = input.numel(); - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; if (numel == 1) { #ifdef PADDLE_WITH_HIP @@ -164,1088 +517,649 @@ void StridedCopyKernel(const Context& dev_ctx, return; } - if (input.meta().is_contiguous()) { - switch (input_rank) { - case 1: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - input_rank)); + dim3 grid(1, 1, 1), block(1, 1, 1); + int rank = input_rank; + int tmp = 1; + + for (int i = 0; i < 3 && i < rank; i++) { + tmp *= input_dims[rank - 1 - i]; + } + + if (rank <= 6 && tmp <= 1024 && + (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { + if (rank >= 1) { + block.x = input_dims[rank - 1]; } - } else if (out->meta().is_contiguous()) { - switch (output_rank) { - case 1: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); + + if (rank >= 2) { + block.y = input_dims[rank - 2]; + } + + if (rank >= 3) { + block.z = input_dims[rank - 3]; + } + + if (input.meta().is_contiguous()) { + switch (rank) { + case 1: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 2: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 3: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + } + } else if (out->meta().is_contiguous()) { + switch (rank) { + case 1: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 2: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 3: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 4: + grid.x = input_dims[rank - 4]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + } + } else { + switch (rank) { + case 1: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 2: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 3: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + } } } else { - switch (input_rank) { - case 1: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 2: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 3: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 4: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 5: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 6: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 7: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 8: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 9: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - input_rank)); + phi::Array cur_input_dims; + block.x = 512; + + if (input.meta().is_contiguous()) { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Contiguous2StridedCaseOneFunc + <<>>( + input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } + } else if (out->meta().is_contiguous()) { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Strided2ContiguousCaseOneFunc + <<>>( + input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } + } else { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + StridedCopyCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } } } } diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc index 2333de4b2e02a9..59f63c5d8cae5b 100644 --- a/paddle/phi/kernels/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/reduce_mean_kernel.cc @@ -38,6 +38,8 @@ PD_REGISTER_KERNEL(mean, float, double, bool, + int, + int64_t, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index 78d34fa14295c8..1deddcf6dc0faf 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -118,7 +118,7 @@ void AdamDenseParamSparseGradKernel( } phi::SelectedRows tmp_grad_merge; - const phi::SelectedRows* grad_merge_ptr; + const phi::SelectedRows* grad_merge_ptr = nullptr; if (is_strict_sorted) { grad_merge_ptr = &grad; } else { diff --git a/paddle/phi/kernels/stride/diagonal_kernel.cc b/paddle/phi/kernels/stride/diagonal_kernel.cc index e8929e6773f533..b4ca6d9b277df5 100644 --- a/paddle/phi/kernels/stride/diagonal_kernel.cc +++ b/paddle/phi/kernels/stride/diagonal_kernel.cc @@ -36,7 +36,7 @@ void DiagonalStridedKernel(const Context& dev_ctx, axis2 += static_cast(x_rank); } - int64_t diag_size; + int64_t diag_size = 0; int64_t x_offset = static_cast(x.offset()); if (offset >= 0) { diag_size = std::max( diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc index 2b637e9da09e86..b5b2ed7d328884 100644 --- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc +++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc @@ -22,23 +22,18 @@ namespace phi { -namespace { -const int ARG_MAX_OUTPUT_DATATYPE_INT32 = 2; -const int ARG_MAX_OUTPUT_DATATYPE_INT64 = 3; -} // Anonymous namespace - template void ArgMaxKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& axis, bool keepdims, bool flatten, - int dtype, + DataType dtype, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; PADDLE_ENFORCE_EQ( - (dtype < 0 || dtype == ARG_MAX_OUTPUT_DATATYPE_INT32 || - dtype == ARG_MAX_OUTPUT_DATATYPE_INT64), + (dtype == DataType::UNDEFINED || dtype == DataType::INT32 || + dtype == DataType::INT64), true, errors::InvalidArgument( "The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], " @@ -60,7 +55,7 @@ void ArgMaxKernel(const Context& dev_ctx, } auto xdims_vec = phi::vectorize(x_dims); int r = 0; - if (dtype != ARG_MAX_OUTPUT_DATATYPE_INT32) { + if (dtype != DataType::INT32) { dev_ctx.template Alloc(out); if (x.dims().size() == 0) { xpu::constant(dev_ctx.x_context(), diff --git a/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc b/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc index a9adffca84700b..c32175b856397d 100644 --- a/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc +++ b/paddle/phi/ops/compat/fused_bn_add_activation_sig.cc @@ -33,13 +33,13 @@ KernelSignature FusedBatchNormAddActGradOpArgumentMapping( const ArgumentMappingContext& ctx UNUSED) { return KernelSignature("fused_bn_add_activation_grad", {"X", - "Y", - "Y@GRAD", "Scale", "Bias", + "Y", "SavedMean", "SavedVariance", - "ReserveSpace"}, + "ReserveSpace", + "Y@GRAD"}, {"momentum", "epsilon", "act_type"}, {"X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"}); } diff --git a/paddle/phi/ops/compat/strided_slice_sig.cc b/paddle/phi/ops/compat/strided_slice_sig.cc index 02b39147878661..0c0e5d0c868f4e 100644 --- a/paddle/phi/ops/compat/strided_slice_sig.cc +++ b/paddle/phi/ops/compat/strided_slice_sig.cc @@ -57,7 +57,7 @@ KernelSignature StridedSliceOpArgumentMapping( "decrease_axis"}; paddle::small_vector outputs = {"Out"}; - const char* kernel_name; + const char* kernel_name = nullptr; if (ctx.IsDenseTensorVectorInput("Input")) { kernel_name = "strided_slice_array"; } else { @@ -106,7 +106,7 @@ KernelSignature StridedSliceGradOpArgumentMapping( "decrease_axis"}; paddle::small_vector outputs = {"Input@GRAD"}; - const char* kernel_name; + const char* kernel_name = nullptr; if (ctx.IsDenseTensorVectorInput("Input")) { kernel_name = "strided_slice_array_grad"; } else { diff --git a/paddle/pir/core/block.h b/paddle/pir/core/block.h index b7a730715c12c5..c61a8f22e54256 100644 --- a/paddle/pir/core/block.h +++ b/paddle/pir/core/block.h @@ -84,6 +84,7 @@ class IR_API Block { ArgsIterator args_end() { return arguments_.end(); } bool args_empty() const { return arguments_.empty(); } uint32_t args_size() const { return arguments_.size(); } + const BlockArgListType &args() const { return arguments_; } BlockArgument argument(uint32_t index) { return arguments_[index]; } Type argument_type(uint32_t index) const { return arguments_[index].type(); } void ClearArguments(); diff --git a/paddle/pir/core/builtin_dialect.cc b/paddle/pir/core/builtin_dialect.cc index 23ba43c3d292ec..60575da6d9472c 100644 --- a/paddle/pir/core/builtin_dialect.cc +++ b/paddle/pir/core/builtin_dialect.cc @@ -53,6 +53,7 @@ void BuiltinDialect::initialize() { RegisterOpsattributes(); + auto iter = attributes.find("output_name"); + IR_ENFORCE(iter != attributes.end() && iter->second.isa(), + "Type of attribute: output_name is not right."); + + // Verify outputs: + IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0."); +} + void CombineOp::Build(Builder &builder, OperationArgument &argument, const std::vector &inputs) { @@ -172,7 +198,7 @@ void CombineOp::Build(Builder &builder, PassStopGradientsDefaultly(argument); } -void CombineOp::Verify() const { +void CombineOp::VerifySig() const { // outputs.size() == 1 IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); @@ -234,7 +260,7 @@ void SliceOp::PassStopGradients(OperationArgument &argument, int index) { pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } -void SliceOp::Verify() const { +void SliceOp::VerifySig() const { // inputs.size() == 1 auto input_size = num_operands(); IR_ENFORCE( @@ -338,7 +364,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) { pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } -void SplitOp::Verify() const { +void SplitOp::VerifySig() const { // inputs.size() == 1 IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1."); @@ -367,7 +393,7 @@ void ConstantOp::Build(Builder &builder, argument.output_types.push_back(output_type); } -void ConstantOp::Verify() const { +void ConstantOp::VerifySig() const { IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(attributes().count("value") > 0, "must has value attribute"); @@ -380,6 +406,7 @@ Attribute ConstantOp::value() const { return attributes().at("value"); } IR_DEFINE_EXPLICIT_TYPE_ID(pir::ModuleOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::GetParameterOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SetParameterOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::ShadowOutputOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::CombineOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SliceOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::SplitOp) diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h index e5327f4c5db45e..19ca96b0526928 100644 --- a/paddle/pir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -31,7 +31,7 @@ class IR_API ModuleOp : public pir::Op { static const char *name() { return "builtin.module"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - void Verify() const; + void VerifySig() const; Program *program(); Block *block(); @@ -56,7 +56,7 @@ class IR_API GetParameterOp : public pir::Op { OperationArgument &argument, // NOLINT const std::string &name, Type type); - void Verify() const; + void VerifySig() const; private: static void PassStopGradients(OperationArgument &argument); // NOLINT @@ -76,7 +76,24 @@ class IR_API SetParameterOp : public pir::Op { OperationArgument &argument, // NOLINT Value parameter, const std::string &name); - void Verify() const; + void VerifySig() const; +}; + +/// +/// \brief ShdowOutputOp: ShdowOutputOp(OpOperand, {StrAttribute, +/// StrAttribute}) +/// +class IR_API ShadowOutputOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "builtin.shadow_output"; } + static constexpr uint32_t attributes_num = 1; + static const char *attributes_name[attributes_num]; + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + Value parameter, + const std::string &name); + void VerifySig() const; }; /// @@ -96,7 +113,7 @@ class IR_API CombineOp : public pir::Op { OperationArgument &argument, // NOLINT const std::vector &inputs); - void Verify() const; + void VerifySig() const; std::vector inputs() { std::vector inputs; for (uint32_t idx = 0; idx < num_operands(); idx++) { @@ -125,7 +142,7 @@ class IR_API SliceOp : public pir::Op { Value input, int index); - void Verify() const; + void VerifySig() const; pir::Value input() { return operand_source(0); } private: @@ -150,7 +167,7 @@ class IR_API SplitOp : public pir::Op { OperationArgument &argument, // NOLINT Value input); - void Verify() const; + void VerifySig() const; pir::Value input() { return operand_source(0); } std::vector outputs() { std::vector res; @@ -186,7 +203,7 @@ class IR_API ConstantOp : public Op { Attribute value, Type output_type); - void Verify() const; + void VerifySig() const; Attribute value() const; }; @@ -198,6 +215,7 @@ void PassStopGradientsDefaultly(OperationArgument &argument); // NOLINT IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ModuleOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::GetParameterOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SetParameterOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ShadowOutputOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CombineOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SliceOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::SplitOp) diff --git a/paddle/pir/core/dialect.h b/paddle/pir/core/dialect.h index 07debaf1960410..8c66f3c1d6a159 100644 --- a/paddle/pir/core/dialect.h +++ b/paddle/pir/core/dialect.h @@ -100,7 +100,8 @@ class IR_API Dialect { ConcreteOp::GetTraitSet(), ConcreteOp::attributes_num, ConcreteOp::attributes_name, - ConcreteOp::VerifyInvariants); + ConcreteOp::VerifySigInvariants, + ConcreteOp::VerifyRegionInvariants); } void RegisterOp(const std::string &name, OpInfoImpl *op_info); diff --git a/paddle/pir/core/ir_context.cc b/paddle/pir/core/ir_context.cc index cab574c68d1f64..1ebd9e4f0c6423 100644 --- a/paddle/pir/core/ir_context.cc +++ b/paddle/pir/core/ir_context.cc @@ -226,7 +226,8 @@ void IrContext::RegisterAbstractAttribute( pir::TypeId type_id, AbstractAttribute &&abstract_attribute) { if (GetRegisteredAbstractAttribute(type_id) == nullptr) { impl().RegisterAbstractAttribute( - type_id, new AbstractAttribute(std::move(abstract_attribute))); + type_id, + new AbstractAttribute(std::move(abstract_attribute))); // NOLINT } else { LOG(WARNING) << " Attribute already registered."; } @@ -258,14 +259,14 @@ Dialect *IrContext::GetOrRegisterDialect( std::vector IrContext::GetRegisteredDialects() { std::vector result; - for (auto dialect_map : impl().registed_dialect_) { + for (auto const &dialect_map : impl().registed_dialect_) { result.push_back(dialect_map.second); } return result; } Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { - for (auto dialect_map : impl().registed_dialect_) { + for (auto const &dialect_map : impl().registed_dialect_) { if (dialect_map.first == dialect_name) { return dialect_map.second; } @@ -277,8 +278,8 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) { void IrContext::RegisterAbstractType(pir::TypeId type_id, AbstractType &&abstract_type) { if (GetRegisteredAbstractType(type_id) == nullptr) { - impl().RegisterAbstractType(type_id, - new AbstractType(std::move(abstract_type))); + impl().RegisterAbstractType( + type_id, new AbstractType(std::move(abstract_type))); // NOLINT } else { LOG(WARNING) << " type already registered."; } @@ -291,7 +292,8 @@ void IrContext::RegisterOpInfo(Dialect *dialect, const std::vector &trait_set, size_t attributes_num, const char **attributes_name, - VerifyPtr verify) { + VerifyPtr verify_sig, + VerifyPtr verify_region) { if (impl().IsOpInfoRegistered(name)) { LOG(WARNING) << name << " op already registered."; } else { @@ -302,7 +304,8 @@ void IrContext::RegisterOpInfo(Dialect *dialect, trait_set, attributes_num, attributes_name, - verify); + verify_sig, + verify_region); impl().RegisterOpInfo(name, info); } } diff --git a/paddle/pir/core/ir_context.h b/paddle/pir/core/ir_context.h index d459f915242290..c20a0d7bba2925 100644 --- a/paddle/pir/core/ir_context.h +++ b/paddle/pir/core/ir_context.h @@ -113,7 +113,8 @@ class IR_API IrContext { const std::vector &trait_set, size_t attributes_num, const char **attributes_name, - void (*verify)(Operation *)); + void (*verify_sig)(Operation *), + void (*verify_region)(Operation *)); /// /// \brief Get registered operaiton infomation. diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 528144437727f7..81cb3b4bcf2244 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -204,11 +204,17 @@ void IrPrinter::PrintValue(Value v) { os << ret->second; return; } - - std::string new_name = "%" + std::to_string(cur_var_number_); - cur_var_number_++; - aliases_[key] = new_name; - os << new_name; + if (v.isa()) { + std::string new_name = "%" + std::to_string(cur_result_number_); + cur_result_number_++; + aliases_[key] = new_name; + os << new_name; + } else { + std::string new_name = "%arg" + std::to_string(cur_block_argument_number_); + cur_block_argument_number_++; + aliases_[key] = new_name; + os << new_name; + } } void IrPrinter::PrintOpResult(Operation* op) { diff --git a/paddle/pir/core/ir_printer.h b/paddle/pir/core/ir_printer.h index 929da4fe332e1c..e4d821c01911bb 100644 --- a/paddle/pir/core/ir_printer.h +++ b/paddle/pir/core/ir_printer.h @@ -71,7 +71,8 @@ class IR_API IrPrinter : public BasicIrPrinter { void PrintOpReturnType(Operation* op); private: - size_t cur_var_number_{0}; + size_t cur_result_number_{0}; + size_t cur_block_argument_number_{0}; std::unordered_map aliases_; }; diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index 8e67a392c51cf3..f0710ff5ec6297 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -63,6 +63,10 @@ class IR_API OpBase { return operation()->attribute(name); } + void VerifySig() {} + + void VerifyRegion() {} + private: Operation *operation_; // Not owned }; @@ -162,14 +166,21 @@ class Op : public OpBase { class EmptyOp : public Op {}; return sizeof(ConcreteOp) == sizeof(EmptyOp); } - // Implementation of `VerifyInvariantsFn` OperationName hook. - static void VerifyInvariants(Operation *op) { + + // Implementation of `VerifySigInvariantsFn` OperationName hook. + static void VerifySigInvariants(Operation *op) { static_assert(HasNoDataMembers(), "Op class shouldn't define new data members"); - op->dyn_cast().Verify(); + op->dyn_cast().VerifySig(); (void)std::initializer_list{ 0, (VerifyTraitOrInterface::call(op), 0)...}; } + + static void VerifyRegionInvariants(Operation *op) { + static_assert(HasNoDataMembers(), + "Op class shouldn't define new data members"); + op->dyn_cast().VerifyRegion(); + } }; } // namespace pir diff --git a/paddle/pir/core/op_info.cc b/paddle/pir/core/op_info.cc index b018bec30448d4..499bfda0e69e7b 100644 --- a/paddle/pir/core/op_info.cc +++ b/paddle/pir/core/op_info.cc @@ -35,7 +35,18 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } -void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); } +void OpInfo::Verify(Operation *operation) const { + VerifySig(operation); + VerifyRegion(operation); +} + +void OpInfo::VerifySig(Operation *operation) const { + impl_->VerifySig()(operation); +} + +void OpInfo::VerifyRegion(Operation *operation) const { + impl_->VerifyRegion()(operation); +} void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr; diff --git a/paddle/pir/core/op_info.h b/paddle/pir/core/op_info.h index 23fc5bfe1b9ebc..a7416c146a90e5 100644 --- a/paddle/pir/core/op_info.h +++ b/paddle/pir/core/op_info.h @@ -54,6 +54,10 @@ class IR_API OpInfo { void Verify(Operation *) const; + void VerifySig(Operation *) const; + + void VerifyRegion(Operation *) const; + template bool HasTrait() const { return HasTrait(TypeId::get()); diff --git a/paddle/pir/core/op_info_impl.cc b/paddle/pir/core/op_info_impl.cc index 12245f12a652a5..33320f1d523670 100644 --- a/paddle/pir/core/op_info_impl.cc +++ b/paddle/pir/core/op_info_impl.cc @@ -24,7 +24,8 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], // NOLINT - VerifyPtr verify) { + VerifyPtr verify_sig, + VerifyPtr verify_region) { // (1) Malloc memory for interfaces, traits, opinfo_impl. size_t interfaces_num = interface_map.size(); size_t traits_num = trait_set.size(); @@ -59,7 +60,8 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, traits_num, attributes_num, attributes_name, - verify)); + verify_sig, + verify_region)); return op_info; } void OpInfoImpl::Destroy(OpInfo info) { diff --git a/paddle/pir/core/op_info_impl.h b/paddle/pir/core/op_info_impl.h index cc63a52d40064a..a08084682f1d00 100644 --- a/paddle/pir/core/op_info_impl.h +++ b/paddle/pir/core/op_info_impl.h @@ -42,14 +42,17 @@ class OpInfoImpl { const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], - VerifyPtr verify); + VerifyPtr verify_sig, + VerifyPtr verify_region); static void Destroy(OpInfo info); TypeId id() const { return op_id_; } Dialect *dialect() const { return dialect_; } - VerifyPtr verify() const { return verify_; } + VerifyPtr VerifySig() const { return verify_sig_; } + + VerifyPtr VerifyRegion() const { return verify_region_; } IrContext *ir_context() const; @@ -76,7 +79,8 @@ class OpInfoImpl { uint32_t num_traits, uint32_t num_attributes, const char **p_attributes, - VerifyPtr verify) + VerifyPtr verify_sig, + VerifyPtr verify_region) : dialect_(dialect), op_id_(op_id), op_name_(op_name), @@ -84,7 +88,8 @@ class OpInfoImpl { num_traits_(num_traits), num_attributes_(num_attributes), p_attributes_(p_attributes), - verify_(verify) {} + verify_sig_(verify_sig), + verify_region_(verify_region) {} void Destroy(); /// The dialect of this Op belong to. @@ -108,7 +113,9 @@ class OpInfoImpl { /// Attributes array address. const char **p_attributes_{nullptr}; - VerifyPtr verify_{nullptr}; + VerifyPtr verify_sig_{nullptr}; + + VerifyPtr verify_region_{nullptr}; }; } // namespace pir diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 6a13963c935870..0dedeafc9ae710 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -123,7 +123,7 @@ Operation *Operation::Create(const std::vector &inputs, // 0. Verify if (op_info) { - op_info.Verify(op); + op_info.VerifySig(op); } return op; } diff --git a/paddle/pir/core/parser/ir_parser.cc b/paddle/pir/core/parser/ir_parser.cc index 008dcdea6c7b10..ef881771ff4cfa 100644 --- a/paddle/pir/core/parser/ir_parser.cc +++ b/paddle/pir/core/parser/ir_parser.cc @@ -77,13 +77,13 @@ Type IrParser::ParseType() { return builder->int16_type(); } else if (type_val == "i32") { ConsumeToken(); - return Int32Type::get(ctx); + return builder->int32_type(); } else if (type_val == "i64") { ConsumeToken(); return Int64Type::get(ctx); } else if (type_val == "index") { ConsumeToken(); - return IndexType::get(ctx); + return builder->index_type(); } else if (type_val == "c64") { ConsumeToken(); return builder->complex64_type(); @@ -95,12 +95,15 @@ Type IrParser::ParseType() { ConsumeAToken("["); std::vector vec_type; Token vec_type_token = PeekToken(); + if (vec_type_token.val_ == "]") { + ConsumeAToken("]"); + } while (vec_type_token.val_ != "]") { Type cur_type = ParseType(); vec_type.push_back(cur_type); vec_type_token = ConsumeToken(); } - return VectorType::get(ctx, vec_type); + return builder->vec_type(vec_type); } else { IR_ENFORCE(type_val.find('.') != std::string::npos, "No function parsing " + type_val + " exists!" + @@ -138,12 +141,20 @@ Attribute IrParser::ParseAttribute() { ConsumeAToken("Float"); ConsumeAToken(")"); std::string val = ConsumeToken().val_; - return builder->float_attr(atof(val.c_str())); + if (val == "-") { + ConsumeAToken("inf"); + float neg_inf = -std::numeric_limits::infinity(); + return builder->float_attr(neg_inf); + } else if (val == "inf") { + float pos_inf = std::numeric_limits::infinity(); + return builder->float_attr(pos_inf); + } + return builder->float_attr(static_cast(atof(val.c_str()))); } else if (attribute_type == "Double") { ConsumeAToken("Double"); ConsumeAToken(")"); std::string val = ConsumeToken().val_; - return builder->double_attr(atof(val.c_str())); + return builder->double_attr(std::stod(val.c_str())); } else if (attribute_type == "Int32") { ConsumeAToken("Int32"); ConsumeAToken(")"); diff --git a/paddle/pir/core/parser/lexer.cc b/paddle/pir/core/parser/lexer.cc index 9bbfd7dbc804a7..8ab23e47576897 100644 --- a/paddle/pir/core/parser/lexer.cc +++ b/paddle/pir/core/parser/lexer.cc @@ -35,16 +35,23 @@ Token Lexer::ConsumeToken() { Token Lexer::PeekToken() { auto pos = is.tellg(); + size_t cache_line = line; + size_t cache_column = column; + auto token = ConsumeToken(); + if (is.eof()) { is.clear(); } is.seekg(pos); + line = cache_line; + column = cache_column; + return token; } char Lexer::GetChar() { - char c = is.get(); + char c = static_cast(is.get()); if (c == '\n') { line++; column = 1; @@ -59,13 +66,14 @@ size_t Lexer::GetColumn() { return column; } size_t Lexer::GetLine() { return line; } void Lexer::SkipWhitespace() { - while (IsSpace(is.peek())) { + while (IsSpace(static_cast(is.peek()))) { GetChar(); } } std::unique_ptr Lexer::LexIdentifer() { - if ((!isalpha(is.peek()) && is.peek() != '_') || IsEndTag(is.peek())) { + if ((!isalpha(is.peek()) && is.peek() != '_') || + IsEndTag(static_cast(is.peek()))) { return nullptr; } std::string token_identifier = ""; @@ -114,7 +122,7 @@ std::unique_ptr Lexer::LexNumberOrArraow() { } std::unique_ptr Lexer::LexEndTagOrNullVal() { - if (!IsEndTag(is.peek())) { + if (!IsEndTag(static_cast(is.peek()))) { return nullptr; } std::string token_end = ""; diff --git a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc index 7166af2ece6363..ed36c0c81cca6a 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_dialect.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_dialect.cc @@ -15,6 +15,6 @@ #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" namespace pir { -void ControlFlowDialect::initialize() { RegisterOps(); } +void ControlFlowDialect::initialize() { RegisterOps(); } } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::ControlFlowDialect) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.cc b/paddle/pir/dialect/control_flow/ir/cf_ops.cc index 69dce41e62badb..7981a6ab963965 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.cc +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.cc @@ -24,4 +24,3 @@ void YieldOp::Build(Builder &builder, } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::YieldOp) -IR_DEFINE_EXPLICIT_TYPE_ID(pir::CondYieldOp) diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.h b/paddle/pir/dialect/control_flow/ir/cf_ops.h index 898f954e09d5f5..7d669c0b648ea0 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.h +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.h @@ -28,33 +28,8 @@ class IR_API YieldOp : public Op { static void Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT const std::vector &Value); - void Verify() {} + void VerifySig() {} }; - -class IR_API CondYieldOp : public Op { - public: - using Op::Op; - static const char *name() { return "cf.cond_yield"; } - static constexpr uint32_t attributes_num = 0; - static constexpr const char **attributes_name = nullptr; - - template - static void Build(Builder &builder, // NOLINT - OperationArgument &argument, // NOLINT - Value cond, - const ValueContainer &inputs); - void Verify() {} -}; - -template -void CondYieldOp::Build(Builder &builder, // NOLINT - OperationArgument &argument, // NOLINT - Value cond, - const ValueContainer &inputs) { - argument.AddInput(cond); - argument.AddInputs(inputs); -} } // namespace pir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::YieldOp); -IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::CondYieldOp); diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index c8ec2df0123412..c838624d2566df 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -71,7 +71,7 @@ class IR_API SymbolicDim : public Op { return "kSymbolicDimAttr"; } - void Verify() {} + void VerifySig() {} }; class IR_API DimOp : public Op { @@ -89,7 +89,7 @@ class IR_API DimOp : public Op { const std::string getName(); void setName(std::string attrValue); OpResult out() { return result(0); } - void Verify() {} + void VerifySig() {} }; class IR_API TieProductEqualOp : public Op { @@ -111,7 +111,7 @@ class IR_API TieProductEqualOp : public Op { const std::vector &rhs); std::vector lhs(); std::vector rhs(); - void Verify() {} + void VerifySig() {} }; class IR_API TieShapeOp : public Op { @@ -132,7 +132,7 @@ class IR_API TieShapeOp : public Op { const std::vector &dims); Value value(); std::vector dims(); - void Verify() {} + void VerifySig() {} }; class IR_API FuncOp : public Op { @@ -147,7 +147,7 @@ class IR_API FuncOp : public Op { OperationArgument &argument); // NOLINT void Print(IrPrinter &printer); // NOLINT Block *block(); - void Verify() {} + void VerifySig() {} }; class IR_API TensorDimOp : public Op { @@ -169,7 +169,7 @@ class IR_API TensorDimOp : public Op { Value index(); Value source(); OpResult out() { return result(0); } - void Verify() {} + void VerifySig() {} }; } // namespace pir::dialect diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 2b130f73f6d077..d746831835ed89 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -25,7 +25,12 @@ bool ShapeAnalysis::IsSameNumElements(Value lhs, Value rhs) { if (!lhs_type || !rhs_type || !lhs_type.HasRank() || !rhs_type.HasRank()) return false; - return IsProductEqual(lhs, 0, lhs_type.GetRank(), rhs, 0, rhs_type.GetRank()); + return IsProductEqual(lhs, + 0, + static_cast(lhs_type.GetRank()), + rhs, + 0, + static_cast(rhs_type.GetRank())); } bool ShapeAnalysis::IsProductEqual( diff --git a/paddle/pir/pass/ir_printing.cc b/paddle/pir/pass/ir_printing.cc index 6171b71c090fcf..901c8bdd89da78 100644 --- a/paddle/pir/pass/ir_printing.cc +++ b/paddle/pir/pass/ir_printing.cc @@ -31,12 +31,8 @@ void PrintIR(Operation *op, bool print_module, std::ostream &os) { return; } - // Find the top-level operation. - auto *top_op = op; - while (auto *parent_op = top_op->GetParentOp()) { - top_op = parent_op; - } - top_op->Print(os); + auto *program = op->GetParentProgram(); + program->Print(os); } } // namespace diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index 00d6cb2f4d3064..ff75f86d6da55a 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,6 +131,7 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter { for (uint32_t i = 0; i < op->num_operands(); ++i) { AddOperandToWorklist(op->operand_source(i)); } + if (op->num_regions() == 0) { RemoveFromWorklist(op); } else { diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 5644ccc5adcb8c..15eb574482232d 100644 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -932,24 +932,69 @@ set -ex fi } +function check_run_sot_ci() { + set +x + # use "git commit -m 'message, test=sot'" to force ci to run + COMMIT_RUN_CI=$(git log -1 --pretty=format:"%s" | grep -w "test=sot" || true) + # check pr title + TITLE_RUN_CI=$(curl -s https://github.com/PaddlePaddle/Paddle/pull/${GIT_PR_ID} | grep "" | grep -i "sot" || true) + if [[ ${COMMIT_RUN_CI} || ${TITLE_RUN_CI} ]]; then + set -x + return + fi + + # git diff + SOT_FILE_LIST=( + paddle/fluid/operators/run_program_op.h + paddle/fluid/operators/run_program_op.cu + paddle/fluid/operators/run_program_op.cc + paddle/fluid/eager/to_static + paddle/fluid/pybind/ + python/ + test/sot + ) + + run_sot_ut="OFF" + for change_file in $(git diff --name-only upstream/develop); + do + for sot_file in ${SOT_FILE_LIST[@]}; + do + if [[ ${change_file} =~ ^"${sot_file}".* ]]; then + echo "Detect change about SOT: " + echo "Changes related to the sot code were detected: " ${change_file} + run_sot_ut="ON" + break + fi + done + if [[ "ON" == ${run_sot_ut} ]]; then + break + fi + done + + if [[ "OFF" == ${run_sot_ut} ]]; then + echo "No SOT-related changes were found" + echo "Skip SOT UT CI" + exit 0 + fi + set -x +} + function run_sot_test() { - PADDLE_SOT_ROOT=$1 - PY_VERSION=$2 + PY_VERSION=$1 PYTHON_WITH_SPECIFY_VERSION=python$PY_VERSION PY_VERSION_NO_DOT=`echo $PY_VERSION | sed 's/\.//g'` export STRICT_MODE=1 export COST_MODEL=False export MIN_GRAPH_SIZE=0 + export SOT_LOG_LEVEL=0 # Install PaddlePaddle $PYTHON_WITH_SPECIFY_VERSION -m pip install ${PADDLE_ROOT}/dist/paddlepaddle-0.0.0-cp${PY_VERSION_NO_DOT}-cp${PY_VERSION_NO_DOT}-linux_x86_64.whl # Install PaddleSOT - cd $PADDLE_SOT_ROOT - $PYTHON_WITH_SPECIFY_VERSION -m pip install -e . + cd $PADDLE_ROOT/test/sot/ # Run unittest - cd tests failed_tests=() for file in ./test_*.py; do @@ -2394,6 +2439,11 @@ set -x ut_endTime_s=`date +%s` echo "CINN testCase Time: $[ $ut_endTime_s - $ut_startTime_s ]s" if [[ "$EXIT_CODE" != "0" ]]; then + rm -f $tmp_dir/* + echo "Summary Failed Tests... " + echo "========================================" + echo "The following tests FAILED: " + echo "${failuretest}" | sort -u exit 8; fi fi @@ -3329,6 +3379,7 @@ function build_pr_and_develop() { mkdir ${PADDLE_ROOT}/build/dev_whl && wget -q -P ${PADDLE_ROOT}/build/dev_whl ${dev_url} cp ${PADDLE_ROOT}/build/dev_whl/paddlepaddle_gpu-0.0.0-cp310-cp310-linux_x86_64.whl ${PADDLE_ROOT}/build/python/dist else + tar --use-compress-program="pigz -1" -cpPf build.tar.gz ${PADDLE_ROOT}/build if [[ ${cmake_change} ]];then rm -rf ${PADDLE_ROOT}/build/Makefile ${PADDLE_ROOT}/build/CMakeCache.txt ${PADDLE_ROOT}/build/build.ninja rm -rf ${PADDLE_ROOT}/build/third_party @@ -3337,6 +3388,15 @@ function build_pr_and_develop() { git checkout -b develop_base_pr upstream/$BRANCH git submodule update --init run_setup ${PYTHON_ABI:-""} "rerun-cmake bdist_wheel" ${parallel_number} + #NOTE(risemeup1):remove build directory of develop branch to avoid conflict with pr branch,we only need whl package of develop branch + rm -rf ${PADDLE_ROOT}/build + if [ -e "${PADDLE_ROOT}/build.tar.gz" ]; then + tar --use-compress-program="pigz -1" -xpf build.tar.gz + else + echo "build.tar.gz of pr branch not exist" + exit 123 + fi + if [ ! -d "${PADDLE_ROOT}/build/python/dist/" ]; then mkdir ${PADDLE_ROOT}/build/python/dist/ fi @@ -4122,15 +4182,14 @@ function main() { run_linux_cpu_test ${PYTHON_ABI:-""} ${PROC_RUN:-1} ;; cicheck_sot) + check_run_sot_ci export WITH_SHARED_PHI=ON - PADDLE_SOT_ROOT=${PADDLE_ROOT}/sot - git clone https://github.com/PaddlePaddle/PaddleSOT.git ${PADDLE_SOT_ROOT} PYTHON_VERSIONS=(3.8 3.9 3.10 3.11) for PY_VERSION in ${PYTHON_VERSIONS[@]}; do ln -sf $(which python${PY_VERSION}) /usr/local/bin/python ln -sf $(which pip${PY_VERSION}) /usr/local/bin/pip run_setup ${PYTHON_ABI:-""} bdist_wheel ${parallel_number} - run_sot_test $PADDLE_SOT_ROOT $PY_VERSION + run_sot_test $PY_VERSION rm -rf ${PADDLE_ROOT}/build/CMakeCache.txt done ;; diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index 667045eaebf97c..8e615f7a6cb114 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "gtest/gtest.h" -#include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/framework/init_default_kernel_signature_map.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/platform/init.h" #include "paddle/phi/core/flags.h" diff --git a/paddle/utils/flags_native_test.cc b/paddle/utils/flags_native_test.cc index 26ef8c12c18753..397072bf2914b7 100644 --- a/paddle/utils/flags_native_test.cc +++ b/paddle/utils/flags_native_test.cc @@ -52,8 +52,8 @@ TEST(flags_native_test, ParseCommandLineFlags) { std::string commandline = "test --paddle_test_int32=3 --paddle_test_uint32=\"4\" " "--paddle_test_string \"modified string\""; - int argc; - char** argv; + int argc = 0; + char** argv = nullptr; SplitCommandlineArg(commandline, &argc, &argv); // Parse commandline flags and check diff --git a/pyproject.toml b/pyproject.toml index f50f5a363b2c0b..3e8da7d18ed6fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,15 +104,6 @@ ignore = [ # Ignore unnecessary lambda in dy2st unittest test_lambda "test/dygraph_to_static/test_lambda.py" = ["PLC3002"] -# Temporarily ignored -"python/paddle/base/**" = [ - "UP030", - "C405", - "B019", # Confirmation required - "C416", - "F821", -] - # B017 "test/auto_parallel/spmd_rules/test_reshape_rule.py" = ["B017"] "test/dygraph_to_static/test_assert.py" = ["B017"] diff --git a/python/cinn/compiler/compiler.py b/python/cinn/compiler/compiler.py index 12f1ffb79d6407..064b97c31f243b 100644 --- a/python/cinn/compiler/compiler.py +++ b/python/cinn/compiler/compiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import cinn from ..runtime import CinnLowerLevelIrJit from .compute_code_generator import ComputeCodeGenerator @@ -31,6 +32,13 @@ def ast_to_llir(fn, inputs_signature): return llir_schedule_generator.parse() +def llir_to_runtime_module(llir_func, target, function_name, arg_names): + cinn_builder = cinn.lang.Module.Builder(function_name, target) + cinn_builder.add_function(llir_func) + llir_module = cinn_builder.build() + return cinn.runtime.Module(llir_module, target, function_name, arg_names) + + def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs): if isinstance(fn, CinnLowerLevelIrJit): llir_func = ast_to_llir(fn, jit_inputs_signature) @@ -39,3 +47,9 @@ def compile(fn, just_convert=False, jit_inputs_signature=[], **kwargs): if just_convert: return llir_func + + rt_module = llir_to_runtime_module( + llir_func, kwargs["target"], fn.__name__, kwargs["arg_names"] + ) + + return rt_module diff --git a/python/cinn/runtime/__init__.py b/python/cinn/runtime/__init__.py index 70753e812e6b63..244567bd855c22 100644 --- a/python/cinn/runtime/__init__.py +++ b/python/cinn/runtime/__init__.py @@ -68,5 +68,6 @@ ) from .cinn_jit import CinnLowerLevelIrJit +from .module import Module -__all__ = ["CinnLowerLevelIrJit"] +__all__ = ["CinnLowerLevelIrJit", "Module"] diff --git a/python/cinn/runtime/data_array.py b/python/cinn/runtime/data_array.py index 4e7c58eced3358..e422005622cac7 100644 --- a/python/cinn/runtime/data_array.py +++ b/python/cinn/runtime/data_array.py @@ -36,32 +36,39 @@ def to_numpy(self): """ Convert DataArray to numpy array """ - cinn_dtype_to_np_dtype = { + np_dtype = "unk" + if self.dtype.is_bfloat16(): # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle - BFloat16(): "uint16", - BFloat16(): "bfloat16", - Float16(): "float16", - Float(32): "float32", - Float(64): "float64", - Int(8): "int8", - Int(16): "int16", - Int(32): "int32", - Int(64): "int64", - UInt(8): "uint8", - # numpy has no 'bfloat16', we use uint16 to hold bfloat16 data, same to Paddle - # "UInt(16): uint16" - UInt(32): "uint32", - UInt(64): "uint64", - Bool(): "bool", - } - for cinn_dtype, np_dtype in cinn_dtype_to_np_dtype.items(): - if isinstance(self.dtype, cinn_dtype): - np_arr = np.empty(self.shape, np_dtype) - assert np_arr.flags["C_CONTIGUOUS"] - self.data.copy_to(np_arr) - return np_arr + np_dtype = "uint16" + elif self.dtype.is_float16(): + np_dtype = "float16" + elif self.dtype.is_float(32, common.Type.specific_type_t.UNK): + np_dtype = "float32" + elif self.dtype.is_float(64, common.Type.specific_type_t.UNK): + np_dtype = "float64" + elif self.dtype.is_int(8): + np_dtype = "int8" + elif self.dtype.is_int(16): + np_dtype = "int16" + elif self.dtype.is_int(32): + np_dtype = "int32" + elif self.dtype.is_int(64): + np_dtype = "int64" + elif self.dtype.is_uint(8): + np_dtype = "uint8" + elif self.dtype.is_uint(32): + np_dtype = "uint32" + elif self.dtype.is_uint(64): + np_dtype = "uint64" + elif self.dtype.is_bool(): + np_dtype = "bool" + else: + raise TypeError(f"no support {self.dtype} in CINN") - raise TypeError(f"no support {self._dtype} in CINN") + np_arr = np.empty(self.shape, np_dtype) + assert np_arr.flags["C_CONTIGUOUS"] + self.data.copy_to(np_arr) + return np_arr @staticmethod def from_numpy(np_array, target=common.DefaultHostTarget()): diff --git a/python/cinn/runtime/module.py b/python/cinn/runtime/module.py new file mode 100644 index 00000000000000..24a31691015944 --- /dev/null +++ b/python/cinn/runtime/module.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cinn +from cinn import framework +from cinn.backends import Compiler + + +class Module: + def __init__(self, llir_module, target, fn_name, arg_names): + self.arg_names = arg_names + self.fn_name = fn_name + self.compiler = Compiler.create(target) + self.compiler.build(llir_module) + self._instruction = framework.Instruction( + target, None, [], arg_names, fn_name + ) + + def __call__(self, *args): + name2pod = {} + for i, name in enumerate(self.arg_names): + if isinstance(args[i], cinn.runtime.data_array.DataArray): + name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i].data) + else: + name2pod[name] = cinn.runtime.cinn_pod_value_t(args[i]) + + self._instruction.run(self.compiler, self.fn_name, name2pod) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5b3e806c3f947b..11a2d07d2096dd 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -407,6 +407,8 @@ i1e, polygamma, polygamma_, + hypot, + hypot_, ) from .tensor.random import ( @@ -904,4 +906,6 @@ 'i1e', 'polygamma', 'polygamma_', + 'hypot', + 'hypot_', ] diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index d39666f6f1c532..d612b93bd1cf31 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -198,9 +198,10 @@ def set_excluded_layers(models, excluded_layers): include_self=True ): layer._cast_to_low_precison = False + excluded_layers_types = tuple(excluded_layers_types) for idx in range(len(models)): for layer in models[idx].sublayers(include_self=True): - if type(layer) in excluded_layers_types: + if isinstance(layer, excluded_layers_types): layer._cast_to_low_precison = False @@ -358,37 +359,38 @@ def amp_guard( % tracer._expected_place ) enable = False - # For xpu: - if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): - warnings.warn('XPUPlace only support float16 amp.') - enable = False - # For custom device: - if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): - warnings.warn('CustomPlace only support float16 amp.') - enable = False - # For gpu float16: Compute Capability should >= 7. - # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. - if tracer._expected_place.is_gpu_place(): - if (dtype == 'float16') and not _is_gpu_float16_supported(): - prop = paddle.device.cuda.get_device_capability() - warnings.warn( - "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." - % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) - ) + if enable: + # For xpu: + if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): + warnings.warn('XPUPlace only support float16 amp.') enable = False - elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): - prop = paddle.device.cuda.get_device_capability() - cuda_version = paddle.version.cuda() - warnings.warn( - "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." - % ( - paddle.device.cuda.get_device_name(), - prop[0], - prop[1], - cuda_version, - ) - ) + # For custom device: + if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'): + warnings.warn('CustomPlace only support float16 amp.') enable = False + # For gpu float16: Compute Capability should >= 7. + # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11. + if tracer._expected_place.is_gpu_place(): + if (dtype == 'float16') and not _is_gpu_float16_supported(): + prop = paddle.device.cuda.get_device_capability() + warnings.warn( + "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d." + % (paddle.device.cuda.get_device_name(), prop[0], prop[1]) + ) + enable = False + elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported(): + prop = paddle.device.cuda.get_device_capability() + cuda_version = paddle.version.cuda() + warnings.warn( + "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s." + % ( + paddle.device.cuda.get_device_name(), + prop[0], + prop[1], + cuda_version, + ) + ) + enable = False amp_dtype = dtype amp_global_state().amp_dtype = amp_dtype diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py index 97a315c1010566..ad5a7cc02aef9e 100644 --- a/python/paddle/autograd/ir_backward.py +++ b/python/paddle/autograd/ir_backward.py @@ -13,6 +13,7 @@ # limitations under the License. import collections +import logging from collections.abc import Sequence import paddle.pir @@ -363,8 +364,8 @@ def make_output_grad(op): for i, value in enumerate(op.results()): if ( value in state.value_to_valuegrad - and len(state.value_to_valuegrad[value]) - ) > 1: + and len(state.value_to_valuegrad[value]) > 1 + ): # one value is input of more than one fwd_op, # so more than one bwd_op create input_grad, # need add sum op to accumulate gradient @@ -556,7 +557,7 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state): if state.value_to_valuegrad[item] != []: outputs_set.add(state.value_to_valuegrad[item][0][0]) else: - raise ValueError("input privided by inputs has no use") + logging.warning("input privided by inputs has no use") inputs_set = set() for output in outputs: diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py index 3e057dcb12cf7e..876db0abc3aa70 100755 --- a/python/paddle/base/backward.py +++ b/python/paddle/base/backward.py @@ -2067,7 +2067,7 @@ def append_backward( # not support double grad in control flow sub-block now. if not is_in_control_flow: if program._appending_grad_times > 1: - input_grad_names_set = set([_append_grad_suffix_(loss.name)]) + input_grad_names_set = {_append_grad_suffix_(loss.name)} # TODO: support _append_backward_ops_with_checkpoints_ in # sub-block (control flow) diff --git a/python/paddle/base/data_feeder.py b/python/paddle/base/data_feeder.py index 2449f456fdc66b..81c6f32a893ca4 100644 --- a/python/paddle/base/data_feeder.py +++ b/python/paddle/base/data_feeder.py @@ -17,6 +17,7 @@ import numpy as np from ..pir import OpResult +from ..pir.core import ParameterMeta from . import core from .framework import ( Variable, @@ -147,7 +148,9 @@ def check_variable_and_dtype( input, input_name, expected_dtype, op_name, extra_message='' ): if in_pir_mode(): - check_type(input, input_name, OpResult, op_name, extra_message) + check_type( + input, input_name, (OpResult, ParameterMeta), op_name, extra_message + ) else: check_type(input, input_name, Variable, op_name, extra_message) check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message) diff --git a/python/paddle/base/dygraph/base.py b/python/paddle/base/dygraph/base.py index 52055fc8f55e0a..5fad89935d4c7e 100644 --- a/python/paddle/base/dygraph/base.py +++ b/python/paddle/base/dygraph/base.py @@ -174,12 +174,14 @@ def enabled(): Examples: .. code-block:: python - import paddle.base as base - - base.enable_dygraph() # Now we are in dygragh mode - print(base.dygraph.enabled()) # True - base.disable_dygraph() - print(base.dygraph.enabled()) # False + >>> import paddle.base as base + + >>> base.enable_dygraph() # Now we are in dygragh mode + >>> print(base.dygraph.enabled()) + True + >>> base.disable_dygraph() + >>> print(base.dygraph.enabled()) + False """ # TODO(jiabin): Make this check as in_dygraph_mode when we support default eager mode. return framework.in_dygraph_mode() @@ -204,14 +206,17 @@ def enable_dygraph(place=None): Examples: .. code-block:: python - import paddle - print(paddle.in_dynamic_mode()) # True, dynamic mode is turn ON by default since paddle 2.0.0 + >>> import paddle + >>> print(paddle.in_dynamic_mode()) + True - paddle.enable_static() - print(paddle.in_dynamic_mode()) # False, Now we are in static graph mode + >>> paddle.enable_static() + >>> print(paddle.in_dynamic_mode()) + False - paddle.disable_static() - print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode + >>> paddle.disable_static() + >>> print(paddle.in_dynamic_mode()) + True """ global global_var @@ -239,14 +244,17 @@ def disable_dygraph(): Examples: .. code-block:: python - import paddle - print(paddle.in_dynamic_mode()) # True, dynamic mode is turn ON by default since paddle 2.0.0 + >>> import paddle + >>> print(paddle.in_dynamic_mode()) + True - paddle.enable_static() - print(paddle.in_dynamic_mode()) # False, Now we are in static graph mode + >>> paddle.enable_static() + >>> print(paddle.in_dynamic_mode()) + False - paddle.disable_static() - print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode + >>> paddle.disable_static() + >>> print(paddle.in_dynamic_mode()) + True """ global global_var @@ -280,40 +288,40 @@ def no_grad(func=None): Examples: - .. code-block:: python - - import numpy as np - import paddle.base as base - - # use as generator - - data = np.array([[2, 3], [4, 5]]).astype('float32') - with base.dygraph.guard(): - l0 = base.Linear(2, 2) # l0.weight.gradient() is None - l1 = base.Linear(2, 2) - with base.dygraph.no_grad(): - # l1.weight.stop_gradient is False - tmp = l1.weight * 2 # tmp.stop_gradient is True - x = base.dygraph.to_variable(data) - y = l0(x) + tmp - o = l1(y) - o.backward() - print(tmp.gradient() is None) # True - print(l0.weight.gradient() is None) # False - - # use as decorator - - @base.dygraph.no_grad - def test_layer(): - with base.dygraph.guard(): - inp = np.ones([3, 1024], dtype='float32') - t = base.dygraph.base.to_variable(inp) - linear1 = base.Linear(1024, 4, bias_attr=False) - linear2 = base.Linear(4, 4) - ret = linear1(t) - dy_ret = linear2(ret) - - test_layer() + .. code-block:: python + + >>> import numpy as np + >>> import paddle.base as base + + >>> # use as generator + + >>> data = np.array([[2, 3], [4, 5]]).astype('float32') + >>> with base.dygraph.guard(): + ... l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None + ... l1 = paddle.nn.Linear(2, 2) + ... with base.dygraph.no_grad(): + ... # l1.weight.stop_gradient is False + ... tmp = l1.weight * 2 # tmp.stop_gradient is True + ... x = base.dygraph.to_variable(data) + ... y = l0(x) + tmp + ... o = l1(y) + ... o.backward() + ... print(tmp.gradient() is None) + ... print(l0.weight.gradient() is None) + True + False + + >>> @base.dygraph.no_grad + >>> def test_layer(): + ... with base.dygraph.guard(): + ... inp = np.ones([3, 1024], dtype='float32') + ... t = base.dygraph.base.to_variable(inp) + ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False) + ... linear2 = paddle.nn.Linear(4, 4) + ... ret = linear1(t) + ... dy_ret = linear2(ret) + ... + >>> test_layer() """ if in_to_static_mode(): @@ -373,16 +381,19 @@ def is_grad_enabled(): Examples: .. code-block:: python - import paddle + >>> import paddle - # Dygraph gradient calculation mode is enabled by default. - paddle.is_grad_enabled() # True + >>> # Dygraph gradient calculation mode is enabled by default. + >>> paddle.is_grad_enabled() + True - with paddle.set_grad_enabled(False): - paddle.is_grad_enabled() # False + >>> with paddle.set_grad_enabled(False): + ... paddle.is_grad_enabled() + False - paddle.enable_static() - paddle.is_grad_enabled() # False + >>> paddle.enable_static() + >>> paddle.is_grad_enabled() + False """ tracer = framework._dygraph_tracer() return tracer._has_grad if tracer else False @@ -407,20 +418,23 @@ class set_grad_enabled(_DecoratorContextManager): Examples: .. code-block:: python - import paddle - x = paddle.to_tensor([1.], stop_gradient=False) - is_train = False - with paddle.set_grad_enabled(is_train): - y = x * 2 - assert(y.stop_gradient == True) - - paddle.set_grad_enabled(True) - y = x * 2 - assert(y.stop_gradient == False) - - paddle.set_grad_enabled(False) - y = x * 2 - assert(y.stop_gradient == True) + >>> import paddle + >>> x = paddle.to_tensor([1.], stop_gradient=False) + >>> is_train = False + >>> with paddle.set_grad_enabled(is_train): + ... y = x * 2 + >>> print(y.stop_gradient) + True + + >>> paddle.set_grad_enabled(True) + >>> y = x * 2 + >>> print(y.stop_gradient) + False + + >>> paddle.set_grad_enabled(False) + >>> y = x * 2 + >>> print(y.stop_gradient) + True """ def __init__(self, mode): @@ -450,38 +464,40 @@ class no_grad_(_DecoratorContextManager): Examples: - .. code-block:: python - - import numpy as np - import paddle - - # use as generator - - data = np.array([[2, 3], [4, 5]]).astype('float32') - l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None - l1 = paddle.nn.Linear(2, 2) - with paddle.no_grad(): - # l1.weight.stop_gradient is False - tmp = l1.weight * 2 # tmp.stop_gradient is True - x = paddle.to_tensor(data) - y = l0(x) + tmp - o = l1(y) - o.backward() - print(tmp.gradient() is None) # True - print(l0.weight.gradient() is None) # False - - # use as decorator - - @paddle.no_grad() - def test_layer(): - inp = np.ones([3, 1024], dtype='float32') - t = paddle.to_tensor(inp) - linear1 = paddle.nn.Linear(1024, 4, bias_attr=False) - linear2 = paddle.nn.Linear(4, 4) - ret = linear1(t) - dy_ret = linear2(ret) - - test_layer() + .. code-block:: python + + >>> import numpy as np + >>> import paddle + + >>> # use as generator + + >>> data = np.array([[2, 3], [4, 5]]).astype('float32') + >>> l0 = paddle.nn.Linear(2, 2) # l0.weight.gradient() is None + >>> l1 = paddle.nn.Linear(2, 2) + >>> with paddle.no_grad(): + ... # l1.weight.stop_gradient is False + ... tmp = l1.weight * 2 # tmp.stop_gradient is True + >>> x = paddle.to_tensor(data) + >>> y = l0(x) + tmp + >>> o = l1(y) + >>> o.backward() + >>> print(tmp.gradient() is None) + True + >>> print(l0.weight.gradient() is None) + False + + >>> # use as decorator + + >>> @paddle.no_grad() + >>> def test_layer(): + ... inp = np.ones([3, 1024], dtype='float32') + ... t = paddle.to_tensor(inp) + ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False) + ... linear2 = paddle.nn.Linear(4, 4) + ... ret = linear1(t) + ... dy_ret = linear2(ret) + ... + >>> test_layer() """ def __enter__(self): @@ -506,30 +522,30 @@ class enable_grad(_DecoratorContextManager): Examples: - .. code-block:: python - - import paddle - - # use as generator + .. code-block:: python - x = paddle.to_tensor([1.], stop_gradient=False) - with paddle.no_grad(): - with paddle.enable_grad(): - y = x * 2 - assert(y.stop_gradient == False) - y.backward() - assert(x.grad is not None) + >>> import paddle - # use as decorator + >>> # use as generator - @paddle.enable_grad() - def double(x): - return x * 2 + >>> x = paddle.to_tensor([1.], stop_gradient=False) + >>> with paddle.no_grad(): + ... with paddle.enable_grad(): + ... y = x * 2 + >>> assert(y.stop_gradient == False) + >>> y.backward() + >>> assert(x.grad is not None) - with paddle.no_grad(): - z = double(x) + >>> # use as decorator - assert(z.stop_gradient == False) + >>> @paddle.enable_grad() + >>> def double(x): + ... return x * 2 + ... + >>> with paddle.no_grad(): + ... z = double(x) + ... + >>> assert(z.stop_gradient == False) """ def __enter__(self): @@ -558,19 +574,19 @@ def guard(place=None): Examples: - .. code-block:: python - - import numpy as np - import paddle.base as base - - with base.dygraph.guard(): - inp = np.ones([3, 1024], dtype='float32') - t = base.dygraph.base.to_variable(inp) - linear1 = base.Linear(1024, 4, bias_attr=False) - linear2 = base.Linear(4, 4) - ret = linear1(t) - dy_ret = linear2(ret) + .. code-block:: python + >>> import numpy as np + >>> import paddle.base as base + + >>> with base.dygraph.guard(): + ... inp = np.ones([3, 1024], dtype='float32') + ... t = base.dygraph.base.to_variable(inp) + ... linear1 = paddle.nn.Linear(1024, 4, bias_attr=False) + ... linear2 = paddle.nn.Linear(4, 4) + ... ret = linear1(t) + ... dy_ret = linear2(ret) + ... """ train = framework.Program() startup = framework.Program() @@ -651,79 +667,85 @@ def grad( .. code-block:: python :name: code-example-1 - import paddle - - def test_dygraph_grad(create_graph): - x = paddle.ones(shape=[1], dtype='float32') - x.stop_gradient = False - y = x * x - - # Since y = x * x, dx = 2 * x - dx = paddle.grad( - outputs=[y], - inputs=[x], - create_graph=create_graph, - retain_graph=True)[0] - - z = y + dx - - # If create_graph = False, the gradient of dx - # would not be backpropagated. Therefore, - # z = x * x + dx, and x.gradient() = 2 * x = 2.0 - - # If create_graph = True, the gradient of dx - # would be backpropagated. Therefore, - # z = x * x + dx = x * x + 2 * x, and - # x.gradient() = 2 * x + 2 = 4.0 - - z.backward() - return x.gradient() - - print(test_dygraph_grad(create_graph=False)) # [2.] - print(test_dygraph_grad(create_graph=True)) # [4.] + >>> import paddle + + >>> def test_dygraph_grad(create_graph): + ... x = paddle.ones(shape=[1], dtype='float32') + ... x.stop_gradient = False + ... y = x * x + ... + ... # Since y = x * x, dx = 2 * x + ... dx = paddle.grad( + ... outputs=[y], + ... inputs=[x], + ... create_graph=create_graph, + ... retain_graph=True)[0] + ... + ... z = y + dx + ... + ... # If create_graph = False, the gradient of dx + ... # would not be backpropagated. Therefore, + ... # z = x * x + dx, and x.gradient() = 2 * x = 2.0 + ... + ... # If create_graph = True, the gradient of dx + ... # would be backpropagated. Therefore, + ... # z = x * x + dx = x * x + 2 * x, and + ... # x.gradient() = 2 * x + 2 = 4.0 + ... + ... z.backward() + ... return x.gradient() + ... + >>> print(test_dygraph_grad(create_graph=False)) + [2.] + >>> print(test_dygraph_grad(create_graph=True)) + [4.] .. code-block:: python :name: code-example-2 - import paddle - - def test_dygraph_grad(grad_outputs=None): - x = paddle.to_tensor(2.0) - x.stop_gradient = False - - y1 = x * x - y2 = x * 3 - - # If grad_outputs=None, dy1 = [1], dy2 = [1]. - # If grad_outputs=[g1, g2], then: - # - dy1 = [1] if g1 is None else g1 - # - dy2 = [1] if g2 is None else g2 - - # Since y1 = x * x, dx = 2 * x * dy1. - # Since y2 = x * 3, dx = 3 * dy2. - # Therefore, the final result would be: - # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2. - - dx = paddle.grad( - outputs=[y1, y2], - inputs=[x], - grad_outputs=grad_outputs)[0] - - return dx.numpy() - - grad_value = paddle.to_tensor(4.0) - # dy1 = [1], dy2 = [1] - print(test_dygraph_grad(None)) # [7.] - - # dy1 = [1], dy2 = [4] - print(test_dygraph_grad([None, grad_value])) # [16.] - - # dy1 = [4], dy2 = [1] - print(test_dygraph_grad([grad_value, None])) # [19.] - - # dy1 = [3], dy2 = [4] - grad_y1 = paddle.to_tensor(3.0) - print(test_dygraph_grad([grad_y1, grad_value])) # [24.] + >>> import paddle + + >>> def test_dygraph_grad(grad_outputs=None): + ... x = paddle.to_tensor(2.0) + ... x.stop_gradient = False + ... + ... y1 = x * x + ... y2 = x * 3 + ... + ... # If grad_outputs=None, dy1 = [1], dy2 = [1]. + ... # If grad_outputs=[g1, g2], then: + ... # - dy1 = [1] if g1 is None else g1 + ... # - dy2 = [1] if g2 is None else g2 + ... + ... # Since y1 = x * x, dx = 2 * x * dy1. + ... # Since y2 = x * 3, dx = 3 * dy2. + ... # Therefore, the final result would be: + ... # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2. + ... + ... dx = paddle.grad( + ... outputs=[y1, y2], + ... inputs=[x], + ... grad_outputs=grad_outputs)[0] + ... + ... return dx.numpy() + ... + >>> grad_value = paddle.to_tensor(4.0) + >>> # dy1 = [1], dy2 = [1] + >>> print(test_dygraph_grad(None)) + 7. + + >>> # dy1 = [1], dy2 = [4] + >>> print(test_dygraph_grad([None, grad_value])) + 16. + + >>> # dy1 = [4], dy2 = [1] + >>> print(test_dygraph_grad([grad_value, None])) + 19. + + >>> # dy1 = [3], dy2 = [4] + >>> grad_y1 = paddle.to_tensor(3.0) + >>> print(test_dygraph_grad([grad_y1, grad_value])) + 24. ''' if in_to_static_mode(): # In dy2static context, we call static interface `gradients` @@ -778,8 +800,6 @@ def check_in_out(in_out_list, name): no_grad_vars = [] elif isinstance(no_grad_vars, core.eager.Tensor): no_grad_vars = [no_grad_vars] - elif isinstance(no_grad_vars, core.eager.Tensor): - no_grad_vars = [no_grad_vars] elif isinstance(no_grad_vars, (list, tuple, set)): no_grad_vars = list(no_grad_vars) for var in no_grad_vars: @@ -849,30 +869,35 @@ def to_variable(value, name=None, zero_copy=None, dtype=None): Examples: - .. code-block:: python - - import numpy as np - import paddle.base as base - - with base.dygraph.guard(base.CPUPlace()): - x = np.ones([2, 2], np.float32) - y = base.dygraph.to_variable(x, zero_copy=False) - x[0][0] = -1 - y[0][0].numpy() # array([1.], dtype=float32) - y = base.dygraph.to_variable(x) - x[0][0] = 0 - y[0][0].numpy() # array([0.], dtype=float32) - c = np.array([2+1j, 2]) - z = base.dygraph.to_variable(c) - z.numpy() # array([2.+1.j, 2.+0.j]) - z.dtype # 'complex128' - - y = base.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) - y.shape # [3L, 2L] - - y = base.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32') - y.shape # [3L, 2L] + .. code-block:: python + >>> import numpy as np + >>> import paddle.base as base + + >>> with base.dygraph.guard(base.CPUPlace()): + ... x = np.ones([2, 2], np.float32) + ... y = base.dygraph.to_variable(x, zero_copy=False) + ... x[0][0] = -1 + ... print(y[0][0].numpy()) + ... y = base.dygraph.to_variable(x) + ... x[0][0] = 0 + ... print(y[0][0].numpy()) + ... c = np.array([2+1j, 2]) + ... z = base.dygraph.to_variable(c) + ... print(z.numpy()) + ... print(z.dtype) + ... + ... y = base.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + ... print(y.shape) + ... + ... y = base.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32') + ... print(y.shape) + 1 + -1 + [2.+1.j, 2.+0.j] + paddle.complex128 + [3, 2] + [3, 2] """ support_type = ( list, diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index 5972b545f93e23..172f73bf7f531f 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -150,7 +150,7 @@ def _index_(var): return int(np.array(var)) @property - def _ndim_(var): + def _ndim(var): return len(var.shape) def ndimension(var): @@ -183,7 +183,7 @@ def _T_(var): ('astype', astype), ('dim', dim), ('ndimension', ndimension), - ('ndim', _ndim_), + ('ndim', _ndim), ('size', _size_), ('T', _T_), # for logical compare diff --git a/python/paddle/base/executor.py b/python/paddle/base/executor.py index b78c19ca6cd00c..9a3138d5e38869 100755 --- a/python/paddle/base/executor.py +++ b/python/paddle/base/executor.py @@ -58,11 +58,11 @@ def global_scope(): Examples: .. code-block:: python - import paddle - import numpy + >>> import paddle + >>> import numpy - paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace()) - numpy.array(paddle.static.global_scope().find_var("data").get_tensor()) + >>> paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace()) + >>> numpy.array(paddle.static.global_scope().find_var("data").get_tensor()) """ return g_scope @@ -98,14 +98,16 @@ def scope_guard(scope): .. code-block:: python - import paddle - import numpy - paddle.enable_static() + >>> import paddle + >>> import numpy + >>> paddle.enable_static() - new_scope = paddle.static.Scope() - with paddle.static.scope_guard(new_scope): - paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace()) - numpy.array(new_scope.find_var("data").get_tensor()) + >>> new_scope = paddle.static.Scope() + >>> with paddle.static.scope_guard(new_scope): + ... paddle.static.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), paddle.CPUPlace()) + >>> numpy.array(new_scope.find_var("data").get_tensor()) + array([[1., 1.], + [1., 1.]]) """ ex = _switch_scope(scope) @@ -123,14 +125,14 @@ def as_numpy(tensor, copy=False): Examples: .. code-block:: python - import paddle.base as base - import numpy + >>> import paddle.base as base + >>> import numpy - new_scope = base.Scope() - with base.scope_guard(new_scope): - base.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), base.CPUPlace()) - tensor = new_scope.find_var("data").get_tensor() - base.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor()) + >>> new_scope = base.Scope() + >>> with base.scope_guard(new_scope): + ... base.global_scope().var("data").get_tensor().set(numpy.ones((2, 2)), base.CPUPlace()) + >>> tensor = new_scope.find_var("data").get_tensor() + >>> base.executor.as_numpy(tensor) # or numpy.array(new_scope.find_var("data").get_tensor()) Args: tensor(Variable): a instance of Tensor @@ -670,12 +672,15 @@ def _as_lodtensor(data, place, dtype=None): For higher dimensional sequence data, please use LoDTensor directly. Examples: - >>> import paddle.base as base - >>> place = base.CPUPlace() - >>> exe = base.executor(place) - >>> data = np.array(size=(100, 200, 300)) - >>> np_outs = map(lambda x: base.executor._as_lodtensor(x, place), data) - >>> ... + + .. code-block:: python + + >>> import numpy as np + >>> import paddle.base as base + >>> place = base.CPUPlace() + >>> exe = base.Executor(place) + >>> data = np.array((100, 200, 300)) + >>> np_outs = map(lambda x: base.executor._as_lodtensor(x, place), data) Args: data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple @@ -739,6 +744,11 @@ def _can_use_interpreter_core(program, place): return True +@lru_cache() +def _warning_once(msg): + logging.warning(msg) + + class FetchHandler: def __init__(self, var_dict=None, period_secs=60): assert var_dict is not None @@ -1044,44 +1054,45 @@ class Executor: Executor Examples: + .. code-block:: python - import paddle - import numpy - import os - - # Executor is only used in static graph mode - paddle.enable_static() - - # Set place explicitly. - # use_cuda = True - # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() - # exe = paddle.static.Executor(place) - - # If you don't set place, PaddlePaddle sets the default device. - exe = paddle.static.Executor() - - train_program = paddle.static.Program() - startup_program = paddle.static.Program() - with paddle.static.program_guard(train_program, startup_program): - data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') - hidden = paddle.static.nn.fc(data, 10) - loss = paddle.mean(hidden) - paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) - - # Run the startup program once and only once. - # Not need to optimize/compile the startup program. - exe.run(startup_program) - - # Run the main program directly without compile. - x = numpy.random.random(size=(10, 1)).astype('float32') - loss_data, = exe.run(train_program, feed={"X": x}, fetch_list=[loss.name]) - - # Or, compiled the program and run. See `CompiledProgram` - # for more details. - compiled_prog = paddle.static.CompiledProgram( - train_program) - loss_data, = exe.run(compiled_prog, feed={"X": x}, fetch_list=[loss.name]) + >>> import paddle + >>> import numpy + >>> import os + + >>> # Executor is only used in static graph mode + >>> paddle.enable_static() + + >>> # Set place explicitly. + >>> # use_cuda = True + >>> # place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + >>> # exe = paddle.static.Executor(place) + + >>> # If you don't set place, PaddlePaddle sets the default device. + >>> exe = paddle.static.Executor() + + >>> train_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + >>> with paddle.static.program_guard(train_program, startup_program): + ... data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') + ... hidden = paddle.static.nn.fc(data, 10) + ... loss = paddle.mean(hidden) + ... paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) + ... + >>> # Run the startup program once and only once. + >>> # Not need to optimize/compile the startup program. + >>> exe.run(startup_program) + + >>> # Run the main program directly without compile. + >>> x = numpy.random.random(size=(10, 1)).astype('float32') + >>> loss_data, = exe.run(train_program, feed={"X": x}, fetch_list=[loss.name]) + + >>> # Or, compiled the program and run. See `CompiledProgram` + >>> # for more details. + >>> compiled_prog = paddle.static.CompiledProgram( + ... train_program) + >>> loss_data, = exe.run(compiled_prog, feed={"X": x}, fetch_list=[loss.name]) """ @@ -1171,10 +1182,8 @@ def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list): def _get_micro_scopes_cache(self, program_cache_key): return self.micro_scope_cache.get(program_cache_key, None) - # just for testing, will be removed later - @lru_cache() def _log_force_set_program_cache(self, use_program_cache): - logging.warning( + _warning_once( f"use_program_cache is force set to {use_program_cache} by FLAGS_FORCE_USE_PROGRAM_CACHE" ) @@ -1440,14 +1449,15 @@ def close(self): None Examples: + .. code-block:: python - import paddle + >>> import paddle - cpu = paddle.CPUPlace() - exe = paddle.static.Executor(cpu) - # execute training or testing - exe.close() + >>> cpu = paddle.CPUPlace() + >>> exe = paddle.static.Executor(cpu) + >>> # execute training or testing + >>> exe.close() """ if not self._closed: self._closed = True @@ -1519,78 +1529,82 @@ def run( List: The fetched result list. Examples: + .. code-block:: python :name: code-example-1 - import paddle - import numpy - - # First create the Executor. - paddle.enable_static() - place = paddle.CPUPlace() # paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - - data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') - hidden = paddle.static.nn.fc(data, 10) - loss = paddle.mean(hidden) - adam = paddle.optimizer.Adam() - adam.minimize(loss) - i = paddle.zeros(shape=[1], dtype='int64') - array = paddle.tensor.array_write(x=loss, i=i) - - # Run the startup program once and only once. - exe.run(paddle.static.default_startup_program()) - - x = numpy.random.random(size=(10, 1)).astype('float32') - loss_val, array_val = exe.run(feed={'X': x}, - fetch_list=[loss.name, array.name]) - print(array_val) - # [array([0.02153828], dtype=float32)] + >>> import paddle + >>> import numpy + + >>> # First create the Executor. + >>> paddle.enable_static() + >>> place = paddle.CPUPlace() # paddle.CUDAPlace(0) + >>> exe = paddle.static.Executor(place) + + >>> data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') + >>> hidden = paddle.static.nn.fc(data, 10) + >>> loss = paddle.mean(hidden) + >>> adam = paddle.optimizer.Adam() + >>> adam.minimize(loss) + >>> i = paddle.zeros(shape=[1], dtype='int64') + >>> array = paddle.tensor.array_write(x=loss, i=i) + + >>> # Run the startup program once and only once. + >>> exe.run(paddle.static.default_startup_program()) + + >>> x = numpy.random.random(size=(10, 1)).astype('float32') + >>> loss_val, array_val = exe.run(feed={'X': x}, + ... fetch_list=[loss.name, array.name]) + >>> print(array_val) + >>> # doctest: +SKIP("Random output") + [array(0.16870381, dtype=float32)] + >>> # doctest: -SKIP .. code-block:: python :name: code-example-2 - # required: gpu - import paddle - import numpy as np - - # First create the Executor. - paddle.enable_static() - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - - data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') - class_dim = 2 - prediction = paddle.static.nn.fc(data, class_dim) - loss = paddle.mean(prediction) - adam = paddle.optimizer.Adam() - adam.minimize(loss) - - # Run the startup program once and only once. - exe.run(paddle.static.default_startup_program()) - build_strategy = paddle.static.BuildStrategy() - binary = paddle.static.CompiledProgram( - paddle.static.default_main_program(), build_strategy=build_strategy) - batch_size = 6 - x = np.random.random(size=(batch_size, 1)).astype('float32') - - prediction, = exe.run(binary, - feed={'X': x}, - fetch_list=[prediction.name]) - # If the user uses two GPU cards to run this python code, the printed result will be - # (6, class_dim). The first dimension value of the printed result is the batch_size. - print("The prediction shape: {}".format( - np.array(prediction).shape)) - print(prediction) - - # Out: - # The prediction shape: (6, 2) - # [[-0.37789783 -0.19921964] - # [-0.3577645 -0.18863106] - # [-0.24274671 -0.12814042] - # [-0.24635398 -0.13003758] - # [-0.49232286 -0.25939852] - # [-0.44514108 -0.2345845 ]] + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> import numpy as np + + >>> # First create the Executor. + >>> paddle.enable_static() + >>> place = paddle.CUDAPlace(0) + >>> exe = paddle.static.Executor(place) + + >>> data = paddle.static.data(name='X', shape=[None, 1], dtype='float32') + >>> class_dim = 2 + >>> prediction = paddle.static.nn.fc(data, class_dim) + >>> loss = paddle.mean(prediction) + >>> adam = paddle.optimizer.Adam() + >>> adam.minimize(loss) + + >>> # Run the startup program once and only once. + >>> exe.run(paddle.static.default_startup_program()) + >>> build_strategy = paddle.static.BuildStrategy() + >>> binary = paddle.static.CompiledProgram( + ... paddle.static.default_main_program(), build_strategy=build_strategy) + >>> batch_size = 6 + >>> x = np.random.random(size=(batch_size, 1)).astype('float32') + + >>> prediction, = exe.run(binary, + ... feed={'X': x}, + ... fetch_list=[prediction.name]) + >>> # If the user uses two GPU cards to run this python code, the printed result will be + >>> # (6, class_dim). The first dimension value of the printed result is the batch_size. + >>> print("The prediction shape: {}".format( + ... np.array(prediction).shape)) + The prediction shape: (6, 2) + + >>> print(prediction) + >>> # doctest: +SKIP("Random output") + [[-0.37789783 -0.19921964] + [-0.3577645 -0.18863106] + [-0.24274671 -0.12814042] + [-0.24635398 -0.13003758] + [-0.49232286 -0.25939852] + [-0.44514108 -0.2345845 ]] + >>> # doctest: -SKIP """ # Temporary FLAGS, just for testing the performance of program cache @@ -2711,7 +2725,7 @@ def _run_using_fleet_executor( if return_numpy: tensor = as_numpy(tensor) else: - tensor = [t for t in tensor] + tensor = list(tensor) if tensor: scope_result_list.append(tensor) @@ -2909,23 +2923,22 @@ def infer_from_dataset( .. code-block:: python - import paddle - - paddle.enable_static() - place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu - exe = paddle.static.Executor(place) - x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64") - y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) - dataset = paddle.base.DatasetFactory().create_dataset() - dataset.set_use_var([x, y]) - dataset.set_thread(1) - # you should set your own filelist, e.g. filelist = ["dataA.txt"] - filelist = [] - dataset.set_filelist(filelist) - exe.run(paddle.static.default_startup_program()) - exe.infer_from_dataset(program=paddle.static.default_main_program(), - dataset=dataset) - + >>> import paddle + + >>> paddle.enable_static() + >>> place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu + >>> exe = paddle.static.Executor(place) + >>> x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64") + >>> y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) + >>> dataset = paddle.base.DatasetFactory().create_dataset() + >>> dataset.set_use_var([x, y]) + >>> dataset.set_thread(1) + >>> # you should set your own filelist, e.g. filelist = ["dataA.txt"] + >>> filelist = [] + >>> dataset.set_filelist(filelist) + >>> exe.run(paddle.static.default_startup_program()) + >>> exe.infer_from_dataset(program=paddle.static.default_main_program(), + ... dataset=dataset) """ return self._run_from_dataset( program, @@ -3032,23 +3045,22 @@ def train_from_dataset( .. code-block:: python - import paddle - - paddle.enable_static() - place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu - exe = paddle.static.Executor(place) - x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64") - y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) - dataset = paddle.base.DatasetFactory().create_dataset() - dataset.set_use_var([x, y]) - dataset.set_thread(1) - # you should set your own filelist, e.g. filelist = ["dataA.txt"] - filelist = [] - dataset.set_filelist(filelist) - exe.run(paddle.static.default_startup_program()) - exe.train_from_dataset(program=paddle.static.default_main_program(), - dataset=dataset) - + >>> import paddle + + >>> paddle.enable_static() + >>> place = paddle.CPUPlace() # you can set place = paddle.CUDAPlace(0) to use gpu + >>> exe = paddle.static.Executor(place) + >>> x = paddle.static.data(name="x", shape=[None, 10, 10], dtype="int64") + >>> y = paddle.static.data(name="y", shape=[None, 1], dtype="int64", lod_level=1) + >>> dataset = paddle.base.DatasetFactory().create_dataset() + >>> dataset.set_use_var([x, y]) + >>> dataset.set_thread(1) + >>> # you should set your own filelist, e.g. filelist = ["dataA.txt"] + >>> filelist = [] + >>> dataset.set_filelist(filelist) + >>> exe.run(paddle.static.default_startup_program()) + >>> exe.train_from_dataset(program=paddle.static.default_main_program(), + ... dataset=dataset) """ return self._run_from_dataset( program, diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index 3aea7e6a85a8a5..ca9bcf5fd8db5b 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -466,13 +466,13 @@ def require_version(min_version, max_version=None): Examples: .. code-block:: python - >>> import paddle.base as base + >>> import paddle >>> # any version >= 0.1.0 is acceptable. - >>> base.require_version('0.1.0') + >>> paddle.utils.require_version('0.1.0') >>> # if 0.1.0 <= version <= 10.0.0, it is acceptable. - >>> base.require_version(min_version='0.1.0', max_version='10.0.0') + >>> paddle.utils.require_version(min_version='0.1.0', max_version='10.0.0') """ if not isinstance(min_version, str): raise TypeError( diff --git a/python/paddle/base/layers/layer_function_generator.py b/python/paddle/base/layers/layer_function_generator.py index f77d26ac50a5f9..2cec3b7e58fa17 100644 --- a/python/paddle/base/layers/layer_function_generator.py +++ b/python/paddle/base/layers/layer_function_generator.py @@ -193,7 +193,7 @@ def infer_and_check_dtype(op_proto, *args, **kwargs): dtype = each.dtype elif dtype != each.dtype: raise ValueError( - "operator {0} must input same dtype. {1} vs {2}".format( + "operator {} must input same dtype. {} vs {}".format( op_type, dtype, each.dtype ) ) @@ -337,8 +337,8 @@ def func(x, name=None): func.__name__ = inplace_op_type func.__doc__ = """ -Inplace version of ``{0}`` API, the output Tensor will be inplaced with input ``x``. -Please refer to :ref:`api_base_layers_{1}`. +Inplace version of ``{}`` API, the output Tensor will be inplaced with input ``x``. +Please refer to :ref:`api_base_layers_{}`. """.format( origin_op_type, origin_op_type ) diff --git a/python/paddle/base/layers/math_op_patch.py b/python/paddle/base/layers/math_op_patch.py index f2b1ac7c6d04d1..1f070882758b92 100644 --- a/python/paddle/base/layers/math_op_patch.py +++ b/python/paddle/base/layers/math_op_patch.py @@ -355,7 +355,7 @@ def pop(self, *args): if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: raise TypeError( - "Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}".format( + "Only Variable with VarType.LOD_TENSOR_ARRAY support `pop` method, but received type: {}".format( self.type ) ) @@ -376,7 +376,7 @@ def _neg_(var): return _scalar_op_(var, -1.0, 0.0) @property - def _ndim_(self): + def _ndim(self): """ Returns the dimension of current Variable @@ -393,7 +393,7 @@ def _ndim_(self): >>> # create a static Variable >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) >>> # print the dimension of the Variable - >>> print(x.ndim()) + >>> print(x.ndim) 3 """ return len(self.shape) @@ -627,7 +627,7 @@ def to_dense(var): ('pop', pop), ('dim', dim), ('ndimension', ndimension), - ('ndim', _ndim_), + ('ndim', _ndim), ( '__add__', _binary_creator_('__add__', 'elementwise_add', False, _scalar_add_), diff --git a/python/paddle/base/multiprocess_utils.py b/python/paddle/base/multiprocess_utils.py index 8d18db0bb3ea85..9b70cacd1c2cd8 100644 --- a/python/paddle/base/multiprocess_utils.py +++ b/python/paddle/base/multiprocess_utils.py @@ -73,7 +73,6 @@ def _func_register(function): if not callable(function): raise TypeError("%s is not callable object." % (function)) # check function object whether hash-able - set([function]) if function not in cls._registered_func_set: atexit.register(_func_exectuor) cls._registered_func_set.add(function) diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index ea43505894f5bc..52e8c43305cf1a 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -14,6 +14,7 @@ import itertools import warnings +from functools import reduce import numpy as np @@ -224,7 +225,8 @@ def replace_ellipsis(var, item): item_remove_var = [ ele for ele in item - if not isinstance(ele, (Variable, np.ndarray)) and ele is not None + if not isinstance(ele, (Variable, paddle.pir.OpResult, np.ndarray)) + and ele is not None ] ell_count = item_remove_var.count(Ellipsis) if ell_count == 0: @@ -284,6 +286,9 @@ def is_integer_or_scalar_tensor(ele): return True if len(ele.shape) == 0 and ele.dtype != paddle.bool: return True + elif isinstance(ele, paddle.pir.OpResult): + if len(ele.shape) == 0 and ele.dtype != paddle.base.libpaddle.BOOL: + return True return False @@ -292,6 +297,11 @@ def is_bool_tensor(ele): if isinstance(ele, Variable) and ele.dtype == paddle.bool: return True + elif ( + isinstance(ele, paddle.pir.OpResult) + and ele.dtype == paddle.base.libpaddle.BOOL + ): + return True return False @@ -303,7 +313,7 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags): attr, dtype="int64" ) for i, dim in enumerate(attr): - if isinstance(dim, Variable): + if isinstance(dim, (Variable, paddle.pir.OpResult)): attrs[attr_name].append(-1) infer_flags[i] = -1 else: @@ -335,14 +345,10 @@ def get_value_for_bool_tensor(var, item): empty_shape = [0] + list(var.shape[i:]) def idx_not_empty(var, item): - from ..tensor import gather_nd - - bool_2_idx = paddle.nonzero(item == True) - return gather_nd(var, bool_2_idx) - - from paddle.static.nn import cond + bool_2_idx = paddle.nonzero(item) + return paddle.gather_nd(var, bool_2_idx) - return cond( + return paddle.static.nn.cond( item.any(), lambda: idx_not_empty(var, item), lambda: paddle.empty(empty_shape, var.dtype), @@ -758,9 +764,14 @@ def parse_index(x, indices): has_advanced_index = True estimated_dim += 1 - elif isinstance(slice_item, paddle.base.Variable): + elif isinstance( + slice_item, (paddle.base.Variable, paddle.pir.OpResult) + ): # In this case, the Variable is not 0-dim Tensor and will be treated as advanced-indexing. - if slice_item.dtype == paddle.bool: + if ( + slice_item.dtype == paddle.bool + or slice_item.dtype == paddle.base.libpaddle.BOOL + ): if slice_item.ndim == 0: # 0-D bool Tensor, same as single PY-bool. none_axes.append(dim) @@ -788,7 +799,12 @@ def parse_index(x, indices): axes.append(dim) use_strided_slice = ( True - if (isinstance(step, paddle.base.Variable) or step != 1) + if ( + isinstance( + step, (paddle.base.Variable, paddle.pir.OpResult) + ) + or step != 1 + ) else use_strided_slice ) return ( diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index bfb9b6e9ba2c66..16b18a071358b4 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -16,7 +16,8 @@ import typing from paddle import pir -from paddle.base.libpaddle.pir import Block, Program +from paddle.base.core import call_decomp, has_decomp +from paddle.base.libpaddle.pir import Block, Operation, Program from paddle.framework import core from . import register @@ -30,6 +31,18 @@ def _build_tensor_tuple(xs): return TypeError(f"Type {type(xs)} is not supported.") +def _analyse_decomp_results(orig_outs, decomp_outs): + assert len(orig_outs) == len(decomp_outs) + res = [] + for org_item, new_item in zip(orig_outs, decomp_outs): + if isinstance(org_item, pir.OpResult): + assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult) + res.append(new_item[0]) + else: + res.append(new_item) + return res + + def _prepare_python_api_arguments(op): """ For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs. @@ -37,7 +50,16 @@ def _prepare_python_api_arguments(op): Args: op (Operator): The target operator. """ - op_inputs = [x.source() for x in op.operands()] + op_inputs = [] + for x in op.operands(): + op_input = x.source() + upper_op = op_input.get_defining_op() + if ( + isinstance(upper_op, Operation) + and upper_op.name() == 'builtin.combine' + ): + op_input = [item.source() for item in upper_op.operands()] + op_inputs.append(op_input) # The inputs of PIR op builtin.combine will be restored as list of tensor. if op.name() in ["builtin.combine"]: return (op_inputs,) @@ -203,15 +225,14 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): if isinstance(block, Block): ops_list = block.ops temp_op = None - temp_inputs = None for idx, op in enumerate(ops_list): op_name = op.name() decom_rule = register.get_decomp_rule(op_name) - lower = decom_rule and op_filter(op) + has_sink_decomp_rule = has_decomp(op) + lower = (decom_rule or has_sink_decomp_rule) and op_filter(op) if op.name() == "builtin.combine": temp_op = op - temp_inputs = _prepare_python_api_arguments(op) if lower: core.prim_config["composite_ops_record"].add(op_name) @@ -219,20 +240,25 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter): temp_op is not None and ops_list[idx - 1].name() == "builtin.combine" ): - input_args = temp_inputs pir.set_insertion_point(temp_op) else: - input_args = _prepare_python_api_arguments(op) pir.set_insertion_point(op) + input_args = _prepare_python_api_arguments(op) orig_outs = op.results() - new_outs = _build_tensor_tuple(decom_rule(*input_args)) + if has_sink_decomp_rule: + decomp_outs = call_decomp(op) + new_outs = _analyse_decomp_results(orig_outs, decomp_outs) + else: + new_outs = _build_tensor_tuple(decom_rule(*input_args)) # Todo: To cover such case: some outputs are no longer needed after decomposition. _check_op_results( op_name, orig_outs, new_outs, orig_vars, dst_vars ) - - op.replace_all_uses_with(new_outs) + if op.name() in ("pd_op.unsqueeze", "pd_op.squeeze"): + orig_outs[0].replace_all_uses_with(new_outs[0]) + else: + op.replace_all_uses_with(new_outs) block.remove_op(op) if temp_op is not None: diff --git a/python/paddle/decomposition/rules.py b/python/paddle/decomposition/rules.py index 924ccf1756b0e0..d64cba8d657ba1 100644 --- a/python/paddle/decomposition/rules.py +++ b/python/paddle/decomposition/rules.py @@ -18,7 +18,6 @@ from .register import register_decomp -@register_decomp('pd_op.mean') def mean(x, axis, keepdim): """define composite rule of op mean""" x_shape = x.shape @@ -56,15 +55,32 @@ def gelu(x, approximate): tanh_out = tanh(kAlpha * (x + GELU_CONSTANT * x * x * x)) out = x * half * (one + tanh_out) return out - else: # gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) - cdf = half * (one + _pir_ops.erf(x * full(x.shape, M_SQRT1_2, x.dtype))) out = x * cdf return out +@register_decomp('pd_op.sqrt') +def sqrt(x): + """ + define composite rule of op sqrt + res = pow(x, 0.5) + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + + y = full(x.shape if len(x.shape) == 0 else [1], 0.5, x.dtype) + res = pow_composite(x, y) + return res if not is_amp else cast(res, dtype) + + @register_decomp('pd_op.rsqrt') def rsqrt(x): """define composite rule of op rsqrt.""" @@ -211,3 +227,120 @@ def add_n(x): for xi in x[1:]: ans = xi + ans return ans + + +@register_decomp('pd_op.silu') +def silu(x): + """ + define composite rule of op silu + res = x / (1 + exp(-x)) + """ + is_amp = False + from paddle.base.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + + sum_temp = exp(-x) + 1 + res = x / sum_temp + return res if not is_amp else cast(res, dtype) + + +@register_decomp('pd_op.softmax') +def softmax(x, axis): + """define composite rule of op softmax""" + is_amp = False + from paddle.base.data_feeder import convert_dtype + + # Softmax need fp32 compute since it has sum op in + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + if not x.shape: + # do not return 1, to ensure gradients + res = exp(x - x) + if is_amp: + res = cast(res, "float16") + return res + max_temp = max(x, axis, keepdim=True) + max_temp.stop_gradient = True + molecular = exp(x - max_temp) + denominator = sum(molecular, axis=axis, keepdim=True) + res = divide(molecular, denominator) + if is_amp: + res = cast(res, dtype) + return res + + +@register_decomp('pd_op.full_like') +def full_like(x, fill_value, dtype, place=None): + """define composite rule of op full_like.""" + """op name: full_like op type name: fill_any_like.""" + """arg place is not used, add it here to keep same as python api.""" + fill_value = fill_value.get_defining_op().attrs()["value"] + val = full(x.shape, fill_value, dtype) + return val + + +@register_decomp('pd_op.stack') +def stack(x, axis): + """ + define composite rule of op stack + unsqueeze each dimension of the input (use reshape), and then concat + """ + x_shape = x[0].shape + if axis < 0: + axis += len(x_shape) + 1 + out_shape = x_shape[:axis] + [1] + x_shape[axis:] + out = concat([reshape(item, out_shape) for item in x], axis) + return out + + +@register_decomp('pd_op.squeeze') +def squeeze(x, axis): + """define composite rule of squeeze""" + """ + canonicalize dim within range 0 to rank and + determine new shape after squeeze op + if axis not specified, remove all dims equal to 1 + otherwise, remove dims equal to 1 in axis + axis can only be list, not int + """ + axis = axis.get_defining_op().attrs()["value"] + rank = len(x.shape) + if rank == 0: + return [assign(x), None] + if len(axis) == 0: + dims = set(range(rank)) + else: + dims = {ax % rank for ax in axis} + new_shape = [] + for d, s in enumerate(x.shape): + if not (s == 1 and (d in dims)): + new_shape.append(s) + out = reshape(x, new_shape) + return [out, None] + + +@register_decomp('pd_op.unsqueeze') +def unsqueeze(x, axis): + """define composite rule of op unsqueeze""" + """using reshape to implement unsqueeze op""" + axis = axis.get_defining_op().attrs()["value"] + x_shape = list(x.shape) + axis_list = list(axis) + for i in axis_list: + if i < 0: + i += len(x_shape) + 1 + x_shape = ( + x_shape[:i] + + [ + 1, + ] + + x_shape[i:] + ) + out = reshape(x, x_shape) + return [out, None] diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index ac452565624409..999ad0cf94f926 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -729,7 +729,7 @@ def _optimization_tuning(self, mode, dataset, batch_size): def _plan(self, mode): if self._planned_mode is None: self._planned_mode = mode - else: + elif self._strategy.auto_mode != "semi": self._init_dist_context(mode) self._planners[mode] = Planner(mode, self._dist_contexts[mode]) @@ -1132,43 +1132,69 @@ def evaluate( else: self._switch_mode(self._mode) - micro_batch_size = self._validate_batch_size(batch_size) - valid_dataloader = self._prepare_dataloader_from_generator( - dataset=valid_data, - capacity=70, - iterable=False, - batch_size=micro_batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn, - ) + if auto_utils.use_new_executor(): + local_batch_size = self._validate_batch_size(batch_size) + valid_dataloader = self._prepare_dataloader( + valid_data, + return_list=False, + batch_size=local_batch_size, + collate_fn=collate_fn, + ) + steps_per_epoch = len(valid_dataloader) if steps is None else steps + else: + micro_batch_size = self._validate_batch_size(batch_size) + valid_dataloader = self._prepare_dataloader_from_generator( + dataset=valid_data, + capacity=70, + iterable=False, + batch_size=micro_batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn, + ) + steps_per_epoch = valid_dataloader._steps + local_batch_size = micro_batch_size + if self._strategy.pipeline.enable: + local_batch_size = micro_batch_size * self._acc_steps fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) cbks = config_callbacks( callbacks, engine=self, - batch_size=micro_batch_size, + batch_size=local_batch_size, log_freq=log_freq, verbose=verbose, metrics=self._metrics_name(), ) - eval_steps = valid_dataloader._steps + eval_steps = steps_per_epoch cbks.on_begin( 'eval', {'steps': eval_steps, 'metrics': self._metrics_name()} ) logs = {} - for step, _ in enumerate(valid_dataloader): - cbks.on_batch_begin('eval', step, logs) + for step, batch in enumerate(valid_dataloader): + if auto_utils.use_new_executor(): + batches = self._validate_batch(batch) + else: + batches = [{}] + try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy, - ) + for micro_batch in batches: + cbks.on_batch_begin('eval', step, logs) + outs = self._executor.run( + self.main_program, + feed=micro_batch, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break + + if steps_per_epoch and step >= steps_per_epoch: + if not auto_utils.use_new_executor(): + valid_dataloader._reset() + break logs = self._prepare_logger( outs, None, step, None, fetch_names, fetch_indices, self._mode ) @@ -1240,34 +1266,57 @@ def predict( else: self._switch_mode(self._mode) - micro_batch_size = self._validate_batch_size(batch_size) - test_dataloader = self._prepare_dataloader_from_generator( - dataset=test_data, - capacity=70, - iterable=False, - batch_size=micro_batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn, - ) + if auto_utils.use_new_executor(): + local_batch_size = self._validate_batch_size(batch_size) + test_dataloader = self._prepare_dataloader( + test_data, + return_list=False, + batch_size=local_batch_size, + collate_fn=collate_fn, + ) + steps_per_epoch = len(test_dataloader) if steps is None else steps + else: + micro_batch_size = self._validate_batch_size(batch_size) + test_dataloader = self._prepare_dataloader_from_generator( + dataset=test_data, + capacity=70, + iterable=False, + batch_size=micro_batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn, + ) + steps_per_epoch = test_dataloader._steps fetch_names, fetch_indices = self._prepare_fetch(None, mode=self._mode) outputs = [] cbks = config_callbacks(callbacks, engine=self, verbose=verbose) - test_steps = test_dataloader._steps + test_steps = steps_per_epoch cbks.on_begin('predict', {'steps': test_steps}) logs = {} - for step, _ in enumerate(test_dataloader): - cbks.on_batch_begin('predict', step, logs) + for step, batch in enumerate(test_dataloader): + if auto_utils.use_new_executor(): + batches = self._validate_batch(batch) + else: + batches = [{}] + try: - outs = self._executor.run( - self.main_program, - fetch_list=fetch_names, - use_program_cache=self._strategy.use_cache, - return_numpy=self._strategy.return_numpy, - ) + for micro_batch in batches: + cbks.on_batch_begin('predict', step, logs) + outs = self._executor.run( + self.main_program, + feed=micro_batch, + fetch_list=fetch_names, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy, + ) except core.EOFException: break + + if steps_per_epoch and step >= steps_per_epoch: + if not auto_utils.use_new_executor(): + test_dataloader._reset() + break logs = self._prepare_logger( outs, None, step, None, fetch_names, fetch_indices, self._mode ) @@ -1281,7 +1330,7 @@ def dataloader( dataset, batch_size=1, shuffle=False, - drop_last=False, + drop_last=True, collate_fn=None, num_workers=0, use_buffer_reader=True, @@ -1451,7 +1500,7 @@ def _prepare_dataloader( return_list=True, batch_size=1, shuffle=False, - drop_last=False, + drop_last=True, collate_fn=None, num_workers=0, use_buffer_reader=True, diff --git a/python/paddle/distributed/auto_parallel/static/operators/common.py b/python/paddle/distributed/auto_parallel/static/operators/common.py index c9b0acc24d24e2..7366d65c0ea895 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/common.py +++ b/python/paddle/distributed/auto_parallel/static/operators/common.py @@ -25,6 +25,7 @@ _get_corresponding_rank, compute_compatible_dims_mapping, is_optimize_op, + set_dist_op_desc_original_id, ) _logger = get_logger( @@ -718,3 +719,15 @@ def get_default_distributed_operator_impl(): num_impls = len(dist_op_default_impl_container.impls) assert num_impls == 1, f"Default dist op has [{num_impls}] impls" return dist_op_default_impl_container.get_impl(0) + + +def copy_op_without_infer_shape(src_op, block, ctx, varname_kwargs): + new_op = block.append_op(type='nop') + new_op_desc = new_op.desc + new_op_desc.copy_from(src_op.desc) + set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) + for input_name in src_op.desc.input_names(): + new_op_desc.set_input(input_name, varname_kwargs[input_name]) + for output_name in src_op.desc.output_names(): + new_op_desc.set_output(output_name, varname_kwargs[output_name]) + return new_op diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py index d83beb82cd12a1..841dc0a5870444 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_flash_attn.py @@ -18,7 +18,7 @@ register_distributed_operator_impl, register_distributed_operator_impl_container, ) -from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0 +from .dist_eltwise import DistributedElementwiseImpl0 class DistributedFlashAttn(DistributedOperatorImplContainer): @@ -30,6 +30,7 @@ def __init__(self, op_type): # Dist FlashAttn with Random Control +# NOTE(zhiqiu): trick implementation, copy dist_attr of q,k,v to out class DistributedFlashAttnImpl0(DistributedElementwiseImpl0): def __init__(self, name): super().__init__(name) @@ -83,12 +84,12 @@ def forward(ctx, *args, **kwargs): src_op._set_attr('rng_name', rng_name) - DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + DistributedElementwiseImpl0.forward(ctx, *args, **kwargs) @staticmethod def backward(ctx, *args, **kwargs): # dropout backward is deterministic by mask, and not need for random state control - DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + DistributedElementwiseImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py index 5066aaa82db896..40b0f109c78037 100644 --- a/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/static/operators/dist_matmul.py @@ -14,7 +14,7 @@ import copy -from paddle.common_ops_import import check_dtype, check_variable_and_dtype +from paddle.common_ops_import import check_variable_and_dtype from paddle.distributed.auto_parallel.static.cost.comm_op_cost import ( AllreduceSumOpCost, IdentityOpCost, @@ -53,8 +53,8 @@ from .common import ( DistributedOperatorImpl, DistributedOperatorImplContainer, + copy_op_without_infer_shape, gradient_synchronization, - infer_shape, is_parameter_related, register_distributed_operator_impl, register_distributed_operator_impl_container, @@ -858,37 +858,16 @@ def forward(ctx, *args, **kwargs): assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape( - main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr - ) + # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - check_variable_and_dtype( - X_var, - 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - '_c_identity', - ) - - attrs = { - 'transpose_X': trans_x, - 'transpose_Y': trans_y, - 'alpha': 1, - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': [X_var], 'Y': [Weight_var]} - matmul_op = main_block.append_op( - type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs - ) - if Out_var.shape != ref_shape_out: - Out_var.desc.set_shape(ref_shape_out) + # copy op + matmul_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) + matmul_op._set_attr('alpha', 1) # matmul matmul_op_dist_attr = OperatorDistAttr() @@ -1166,61 +1145,19 @@ def forward(ctx, *args, **kwargs): ) group = new_process_group(group_ranks) - check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' - ) - check_dtype( - X_var.dtype, - 'dtype', - ['float16', 'float32', 'float64', 'uint16'], - 'linear', - ) - attrs = { - 'transpose_X': trans_x, - 'transpose_Y': trans_y, - 'alpha': 1, - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': X_var, 'Y': Weight_var} - # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - - intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key( - ".".join(["c_allreduce_sum", 'tmp']) - ), - shape=Out_var.shape, - dtype=Out_var.dtype, - type=Out_var.type, - lod_level=Out_var.lod_level, - persistable=False, - is_data=False, - need_check_feed=Out_var.desc.need_check_feed(), - ) - # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program( - intermediate_var_0, out_var_dist_attr - ) - - matmul_op = main_block.append_op( - type='matmul', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs, - ) - if intermediate_var_0.shape != ref_shape: - intermediate_var_0.desc.set_shape(ref_shape) + # copy op + matmul_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) + + # add allreduce (inplace) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', - inputs={'X': intermediate_var_0}, + inputs={'X': Out_var}, outputs={'Out': Out_var}, attrs={ 'ring_id': group.id, @@ -1229,8 +1166,6 @@ def forward(ctx, *args, **kwargs): OP_ROLE_KEY: src_op.attr('op_role'), }, ) - if Out_var.shape != ref_shape: - Out_var.desc.set_shape(ref_shape) # set dist op's dist_attr with serial op's dist_attr # matmul @@ -1673,55 +1608,25 @@ def forward(ctx, *args, **kwargs): ), "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( matmul_col_dim_mapping ) - process_mesh_shape = op_dist_attr.process_mesh.shape - process_mesh_group = op_dist_attr.process_mesh.process_ids - - parallel_axis = matmul_col_dim_mapping - group_ranks = _get_comm_group( - process_mesh_group, process_mesh_shape, parallel_axis, rank_id - ) - group = new_process_group(group_ranks) # infer new var shape with op dist attr x_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(X_var) assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape( - main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr - ) + # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - - check_variable_and_dtype( - X_var, - 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - '_c_identity', - ) - - attrs = { - 'trans_x': trans_x, - 'trans_y': trans_y, - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': [X_var], 'Y': [Weight_var]} - matmul_v2_op = main_block.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': Out_var}, - attrs=attrs, + + # copy op + matmul_v2_op = copy_op_without_infer_shape( + src_op, main_block, ctx, kwargs ) - if Out_var.shape != ref_shape_out: - Out_var.desc.set_shape(ref_shape_out) - # matmulv2 + # set distattr matmulv2_op_dist_attr = OperatorDistAttr() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type @@ -1995,60 +1900,20 @@ def forward(ctx, *args, **kwargs): ) group = new_process_group(group_ranks) - check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' - ) - check_dtype( - X_var.dtype, - 'dtype', - ['float16', 'float32', 'float64', 'uint16'], - 'linear', - ) - attrs = { - 'trans_x': trans_x, - 'trans_y': trans_y, - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': X_var, 'Y': Weight_var} - # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - - intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key( - ".".join(["c_allreduce_sum", 'tmp']) - ), - shape=Out_var.shape, - dtype=Out_var.dtype, - type=Out_var.type, - lod_level=Out_var.lod_level, - persistable=False, - is_data=False, - need_check_feed=Out_var.desc.need_check_feed(), - ) - # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program( - intermediate_var_0, out_var_dist_attr - ) - - matmul_v2_op = main_block.append_op( - type='matmul_v2', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs, + + # copy op + matmul_v2_op = copy_op_without_infer_shape( + src_op, main_block, ctx, kwargs ) - if intermediate_var_0.shape != ref_shape: - intermediate_var_0.desc.set_shape(ref_shape) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', - inputs={'X': intermediate_var_0}, + inputs={'X': Out_var}, outputs={'Out': Out_var}, attrs={ 'ring_id': group.id, @@ -2057,11 +1922,8 @@ def forward(ctx, *args, **kwargs): OP_ROLE_KEY: src_op.attr('op_role'), }, ) - if Out_var.shape != ref_shape: - Out_var.desc.set_shape(ref_shape) # set dist op's dist_attr with serial op's dist_attr - # matmulv2 matmulv2_op_dist_attr = OperatorDistAttr() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type @@ -2501,60 +2363,17 @@ def forward(ctx, *args, **kwargs): assert x_tensor_dist_attr is not None identity_var_dist_attr = op_dist_attr.get_input_dist_attr(X_var.name) assert identity_var_dist_attr is not None - ref_shape_x = infer_shape( - main_block, X_var, x_tensor_dist_attr, identity_var_dist_attr - ) + # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape_out = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - - check_variable_and_dtype( - X_var, - 'tensor', - ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'], - '_c_identity', - ) - - attrs = { - "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), - "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': X_var, 'Y': Weight_var} - - inputs_ref_shape = {} - inputs_original_shape = {} - for var_name in inputs: - if var_name == "X": - var = X_var - else: - var = inputs[var_name] - inputs_original_shape[var_name] = var.shape - input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) - input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) - input_ref_shape = infer_shape( - main_block, var, input_tensor_dist_attr, input_var_dist_attr - ) - inputs_ref_shape[var_name] = input_ref_shape - var.desc.set_shape(input_ref_shape) - - mul_op = main_block.append_op( - type='mul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs - ) - if Out_var.shape != ref_shape_out: - Out_var.desc.set_shape(ref_shape_out) - for var_name in inputs: - var = inputs[var_name] - original_shape = inputs_original_shape[var_name] - var.desc.set_shape(original_shape) + # copy op + mul_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) - # matmulv2 + # set distattr matmulv2_op_dist_attr = OperatorDistAttr() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type @@ -2816,80 +2635,18 @@ def forward(ctx, *args, **kwargs): ) group = new_process_group(group_ranks) - check_variable_and_dtype( - X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear' - ) - check_dtype( - X_var.dtype, - 'dtype', - ['float16', 'float32', 'float64', 'uint16'], - 'linear', - ) - # attrs = {'trans_x': False, 'trans_y': False} - attrs = { - "x_num_col_dims": src_op.desc.attr("x_num_col_dims"), - "y_num_col_dims": src_op.desc.attr("y_num_col_dims"), - OP_ROLE_KEY: src_op.attr('op_role'), - } - inputs = {'X': X_var, 'Y': Weight_var} - # infer out var shape with op dist attr out_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(Out_var) assert out_tensor_dist_attr is not None out_var_dist_attr = op_dist_attr.get_output_dist_attr(Out_var.name) assert out_var_dist_attr is not None - ref_shape = infer_shape( - main_block, Out_var, out_tensor_dist_attr, out_var_dist_attr - ) - - intermediate_var_0 = main_block.create_var( - name=unique_name.generate_with_ignorable_key( - ".".join(["c_allreduce_sum", 'tmp']) - ), - shape=Out_var.shape, - dtype=Out_var.dtype, - type=Out_var.type, - lod_level=Out_var.lod_level, - persistable=False, - is_data=False, - need_check_feed=Out_var.desc.need_check_feed(), - ) - # set intermediate_var_0's dist_attr with Out_var's dist_attr - ctx.set_tensor_dist_attr_for_program( - intermediate_var_0, out_var_dist_attr - ) - - inputs_ref_shape = {} - inputs_original_shape = {} - for var_name in inputs: - var = inputs[var_name] - inputs_original_shape[var_name] = var.shape - input_tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var) - input_var_dist_attr = op_dist_attr.get_input_dist_attr(var.name) - input_ref_shape = infer_shape( - main_block, var, input_tensor_dist_attr, input_var_dist_attr - ) - inputs_ref_shape[var_name] = input_ref_shape - var.desc.set_shape(input_ref_shape) - - mul_op = main_block.append_op( - type='mul', - inputs=inputs, - outputs={'Out': intermediate_var_0}, - attrs=attrs, - ) - - if intermediate_var_0.shape != ref_shape: - intermediate_var_0.desc.set_shape(ref_shape) - for var_name in inputs: - var = inputs[var_name] - original_shape = inputs_original_shape[var_name] - var.desc.set_shape(original_shape) + # copy op + mul_op = copy_op_without_infer_shape(src_op, main_block, ctx, kwargs) c_allreduce_sum_op = main_block.append_op( type='c_allreduce_sum', - inputs={'X': intermediate_var_0}, + inputs={'X': Out_var}, outputs={'Out': Out_var}, attrs={ 'ring_id': group.id, @@ -2899,9 +2656,6 @@ def forward(ctx, *args, **kwargs): }, ) - if Out_var.shape != ref_shape: - Out_var.desc.set_shape(ref_shape) - # set dist op's dist_attr with serial op's dist_attr # matmulv2 matmulv2_op_dist_attr = OperatorDistAttr() diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer.py b/python/paddle/distributed/auto_parallel/static/parallelizer.py index 6e4eecf89fc5e8..06d2f4a995b750 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer.py @@ -45,7 +45,7 @@ get_world_process_group, ) from .reshard import Resharder -from .utils import SerialProgramInfo, make_data_unshard, set_grad_var_shape +from .utils import SerialProgramInfo, make_data_unshard _logger = get_logger(logging.INFO) @@ -260,8 +260,6 @@ def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): dist_main_prog, dist_startup_prog, dist_params_grads ) - set_grad_var_shape(dist_main_prog, self._dist_context) - make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) resharder = Resharder( diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 38b9ae8dcda59a..6f0a1db1a3bff9 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -26,12 +26,7 @@ from .partitioner import Partitioner from .process_group import get_world_process_group from .reshard import Resharder -from .utils import ( - get_pp_stage, - is_sequential_run, - set_grad_var_shape, - use_new_executor, -) +from .utils import get_pp_stage, is_sequential_run, use_new_executor class Parallelizer: @@ -122,7 +117,7 @@ def parallel(self, rank, parameter_list=None): time.time() - time0, self._mode ) ) - set_grad_var_shape(dist_main_prog, self._dist_context) + resharder = Resharder( dist_main_prog, dist_startup_prog, diff --git a/python/paddle/distributed/auto_parallel/static/process_group.py b/python/paddle/distributed/auto_parallel/static/process_group.py index df881be1a31e3a..6bf7b18cabcb0f 100644 --- a/python/paddle/distributed/auto_parallel/static/process_group.py +++ b/python/paddle/distributed/auto_parallel/static/process_group.py @@ -13,7 +13,6 @@ # limitations under the License import hashlib -import os from collections import OrderedDict import paddle @@ -158,10 +157,10 @@ def instantiate(self): strategy.nrings = 1 if core.is_compiled_with_cuda(): place = core.CUDAPlace(genv.device_id) - use_new_comm = os.getenv( - "FLAGS_dynamic_static_unified_comm", "0" - ) - if use_new_comm in ["1", "True", "true"]: + use_new_comm = paddle.get_flags( + "FLAGS_dynamic_static_unified_comm" + )["FLAGS_dynamic_static_unified_comm"] + if use_new_comm: store = core.create_or_get_global_tcp_store() endpoints_str = "" for endpoint in strategy.trainer_endpoints: diff --git a/python/paddle/distributed/auto_parallel/static/reshard.py b/python/paddle/distributed/auto_parallel/static/reshard.py index cf1ed597536e30..9cc1a61610d808 100644 --- a/python/paddle/distributed/auto_parallel/static/reshard.py +++ b/python/paddle/distributed/auto_parallel/static/reshard.py @@ -33,7 +33,7 @@ from .dist_attribute import TensorDistAttr from .dist_context import DistributedContext from .process_group import new_process_group -from .utils import is_gradient_clip_op +from .utils import is_gradient_clip_op, is_optimize_op # NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded. _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] @@ -1786,6 +1786,17 @@ def parse_op_desc( source_tensor = get_var_with_recursion( var_name, block, self.auto_parallel_main_prog ) + + def is_grad(name): + return name.endswith('GRAD') + + # all op that generate grad is marked as OpRole.Backward + op_role = ( + OpRole.Backward + if is_optimize_op(reshard_op) and is_grad(var_name) + else reshard_op.attr('op_role') + ) + for op_desc in op_desc_list: if isinstance(op_desc, AllGatherOpDesc): if var_name not in self.has_allgather.keys(): @@ -1799,7 +1810,7 @@ def parse_op_desc( block, idx, source_tensor, - reshard_op.attr('op_role'), + op_role, paddle.int64, ) tensor_list, idx_offset = Inserter.insert_allgather_op( @@ -1807,7 +1818,7 @@ def parse_op_desc( idx + 1, out_cast, op_desc.group, - reshard_op.attr('op_role'), + op_role, ) idx += idx_offset tensor_name_list = [] @@ -1816,7 +1827,7 @@ def parse_op_desc( block, idx, var, - reshard_op.attr('op_role'), + op_role, paddle.bool, ) tensor_name_list.append(out_cast.name) @@ -1830,7 +1841,7 @@ def parse_op_desc( idx, source_tensor, op_desc.group, - reshard_op.attr('op_role'), + op_role, ) idx += idx_offset tensor_name_list = [var.name for var in tensor_list] @@ -1862,7 +1873,7 @@ def parse_op_desc( block, idx, source_tensor, - reshard_op.attr('op_role'), + op_role, paddle.int64, ) Inserter.insert_send_op( @@ -1871,7 +1882,7 @@ def parse_op_desc( out_cast, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) idx += 2 else: @@ -1881,7 +1892,7 @@ def parse_op_desc( source_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) idx += 1 self.has_sent[var_name].append(op_desc.dst) @@ -1909,13 +1920,13 @@ def parse_op_desc( recv_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) out_cast = Inserter.insert_cast_op( block, idx + 1, recv_tensor, - reshard_op.attr('op_role'), + op_role, paddle.bool, ) tensor_list.append(out_cast) @@ -1935,7 +1946,7 @@ def parse_op_desc( recv_tensor, op_desc.src, op_desc.dst, - reshard_op.attr('op_role'), + op_role, ) # for lod tensor, need reset lod after received @@ -1958,7 +1969,7 @@ def parse_op_desc( idx + 1, recv_tensor, tmp_var, - reshard_op.attr('op_role'), + op_role, ) ) tensor_list.append(reset_lod_out) @@ -1988,7 +1999,7 @@ def parse_op_desc( partition_index_list[index], block, idx_list, - reshard_op.attr('op_role'), + op_role, ) idx = idx_list[0] @@ -2013,7 +2024,7 @@ def parse_op_desc( ends=op_desc.ends, axes=op_desc.axes, new_var_name=new_name, - op_role=reshard_op.attr('op_role'), + op_role=op_role, ) else: target_tensor = Inserter.insert_c_concat_op( @@ -2021,7 +2032,7 @@ def parse_op_desc( idx, source_tensor, op_desc.group, - reshard_op.attr('op_role'), + op_role, ) assert target_tensor is not None diff --git a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py index 6a3365eff018b4..6f1c26e5f235c4 100644 --- a/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/static/tuner/optimization_tuner.py @@ -38,10 +38,7 @@ new_process_group, ) from paddle.distributed.auto_parallel.static.reshard import Resharder -from paddle.distributed.auto_parallel.static.utils import ( - debug_program, - set_grad_var_shape, -) +from paddle.distributed.auto_parallel.static.utils import debug_program from paddle.distributed.passes import PassContext, new_pass from paddle.static import append_backward, program_guard from paddle.utils import unique_name @@ -353,8 +350,6 @@ def _apply_optimization(self, trial): ) completer.complete_update_annotation(dist_main_prog) - # Do reshard process - set_grad_var_shape(dist_main_prog, dist_context) resharder = Resharder( dist_main_prog, dist_startup_prog, diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py index fac4df3d451446..da57e126058a56 100644 --- a/python/paddle/distributed/auto_parallel/static/utils.py +++ b/python/paddle/distributed/auto_parallel/static/utils.py @@ -1205,149 +1205,6 @@ def _get_split_indices( return split_indices_list -def set_grad_var_shape(program, dist_context): - from paddle.distributed.fleet.meta_optimizers.common import OpRole - - from .operators.common import infer_shape - - block = program.global_block() - vars = block.vars - appended_grad_times = 0 - grad_var_to_var = dist_context.dist_op_context.grad_var_to_var - - for idx, op in enumerate(block.ops): - if int(op.attr('op_role')) != int(OpRole.Backward): - continue - - if ( - int(block.ops[idx - 1].attr('op_role')) == int(OpRole.Forward) - or int(block.ops[idx - 1].attr('op_role')) == 257 - ): - appended_grad_times += 1 - - if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: - break - - if op.type in ["sum", "concat", "shape"]: - continue - - op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - assert op_dist_attr is not None - - for var_name in op.output_arg_names: - if "@GRAD" not in var_name: - continue - if var_name in grad_var_to_var[appended_grad_times]: - forward_var_name = grad_var_to_var[appended_grad_times][ - var_name - ] - else: - forward_var_name = var_name[: var_name.find("@GRAD")] - - if op.type in [ - "c_allreduce_sum", - "c_identity", - "scale", - "cast", - "fill_any_like", - ]: - forward_var_name = op.input_arg_names[0] - elif ( - op.type == "matmul_v2_grad" - or op.type == "matmul_grad" - or op.type == "mul_grad" - ): - forward_var_name = None - for output_name in op.output_names: - if var_name in op.output(output_name): - assert "@GRAD" in output_name - input_name = output_name[: output_name.find("@GRAD")] - assert len(op.input(input_name)) == 1 - forward_var_name = op.input(input_name)[0] - assert forward_var_name is not None - - need_set_shape_list = [ - "reshape2_grad", - "softmax_with_cross_entropy_grad", - "transpose2_grad", - "softmax_grad", - "cross_entropy_grad2", - "dropout_grad", - "tanh_grad", - "slice", - "assign", - "matmul_v2_triple_grad", - "elementwise_add_triple_grad", - "fill_constant", - "sqrt_grad", - "fused_softmax_mask_upper_triangle_grad", - "flatten_contiguous_range_grad", - "relu_grad", - "exp_grad", - "sigmoid_grad", - "unsqueeze2_grad", - "fused_dropout_add_grad", - ] - forward_list = [ - "reshape2", - "softmax_with_cross_entropy", - "transpose2", - "softmax", - "cross_entropy2", - "dropout", - "tanh", - ["slice_grad", "c_allgather"], - "assign", - "matmul_v2_grad_grad", - "elementwise_add_grad_grad", - "shape", - "sqrt", - "fused_softmax_mask_upper_triangle", - "flatten_contiguous_range", - "relu", - "exp", - "sigmoid", - "unsqueeze2", - "fused_dropout_add", - ] - if op.type in need_set_shape_list: - for forward_op in block.ops: - idx = need_set_shape_list.index(op.type) - forward_op_name = forward_list[idx] - if ( - forward_op.type in forward_op_name - and forward_var_name in forward_op.input_arg_names - ): - op_dist_attr = ( - dist_context.get_op_dist_attr_for_program( - forward_op - ) - ) - break - - forward_input_dist_attr = op_dist_attr.get_input_dist_attr( - forward_var_name - ) - assert ( - forward_input_dist_attr is not None - ), f"{forward_var_name, str(op)}" - forward_var = vars[forward_var_name] - forward_var_dist_attr = ( - dist_context.get_tensor_dist_attr_for_program(forward_var) - ) - assert forward_var_dist_attr is not None - grad_var = vars[var_name] - ref_shape = infer_shape( - block, - forward_var, - forward_var_dist_attr, - forward_input_dist_attr, - ) - - if list(grad_var.shape) != ref_shape: - grad_var.desc.set_shape(ref_shape) - - def is_forward_op(op): op_role = int(op.attr('op_role')) return OP_ROLE_KEY in op.attr_names and ( diff --git a/python/paddle/distributed/fleet/base/private_helper_function.py b/python/paddle/distributed/fleet/base/private_helper_function.py index c5199eb46a7475..0da733c0f24c65 100644 --- a/python/paddle/distributed/fleet/base/private_helper_function.py +++ b/python/paddle/distributed/fleet/base/private_helper_function.py @@ -16,6 +16,8 @@ import time from contextlib import closing +import paddle + __all__ = [] @@ -33,6 +35,15 @@ def wait_server_ready(endpoints): >>> wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"]) """ + try: + use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[ + "FLAGS_dynamic_static_unified_comm" + ] + except: + use_new_comm = False + + if use_new_comm: + return assert not isinstance(endpoints, str) while True: all_ok = True diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py index 5e90584b25b5e7..f18f7aeb068761 100755 --- a/python/paddle/distributed/fleet/fleet.py +++ b/python/paddle/distributed/fleet/fleet.py @@ -105,54 +105,55 @@ class Fleet: Returns: Fleet: A Fleet instance - + Examples: .. code-block:: python :name: code-example1 - # Example1: for collective training - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet + >>> # Example1: for collective training + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) + >>> fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> strategy = fleet.DistributedStrategy() + >>> linear = paddle.nn.Linear(10, 10) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - # do distributed training + >>> # do distributed training .. code-block:: python :name: code-example2 - # Example2: for parameter server training - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(strategy=strategy) + >>> # Example2: for parameter server training + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(strategy=strategy) - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001) + >>> optimizer = fleet.distributed_optimizer(optimizer) - if fleet.is_first_worker(): - print("this is first worker") + >>> if fleet.is_first_worker(): + ... print("this is first worker") - print("current node index: {}".format(fleet.worker_index())) - print("total number of worker num: {}".format(fleet.worker_num())) + >>> print("current node index: {}".format(fleet.worker_index())) + >>> print("total number of worker num: {}".format(fleet.worker_num())) - if fleet.is_worker(): - print("this is worker") - print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True))) + >>> if fleet.is_worker(): + ... print("this is worker") + >>> print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True))) - print("server num: {}".format(fleet.server_num())) - print("server endpoints: {}".format(fleet.server_endpoints(to_string=True))) + >>> print("server num: {}".format(fleet.server_num())) + >>> print("server endpoints: {}".format(fleet.server_endpoints(to_string=True))) - if fleet.is_server(): - print("this is server") - fleet.stop_worker() + >>> if fleet.is_server(): + ... print("this is server") + >>> fleet.stop_worker() """ @@ -202,37 +203,37 @@ def init( .. code-block:: python :name: code-example1 - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() .. code-block:: python :name: code-example2 - import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) + >>> import paddle.distributed.fleet as fleet + >>> fleet.init(is_collective=True) .. code-block:: python :name: code-example3 - import paddle.distributed.fleet as fleet - role = fleet.PaddleCloudRoleMaker() - fleet.init(role) + >>> import paddle.distributed.fleet as fleet + >>> role = fleet.PaddleCloudRoleMaker() + >>> fleet.init(role) .. code-block:: python :name: code-example4 - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(strategy=strategy) + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(strategy=strategy) .. code-block:: python :name: code-example5 - import paddle.distributed.fleet as fleet - strategy = fleet.DistributedStrategy() - fleet.init(log_level = "DEBUG") + >>> import paddle.distributed.fleet as fleet + >>> strategy = fleet.DistributedStrategy() + >>> fleet.init(log_level = "DEBUG") """ from paddle.distributed import parallel_helper @@ -454,9 +455,9 @@ def is_first_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_first_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_first_worker() """ return self._role_maker._is_first_worker() @@ -472,9 +473,9 @@ def worker_index(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_index() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_index() """ return self._role_maker._worker_index() @@ -490,9 +491,9 @@ def worker_num(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_num() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_num() """ return self._role_maker._worker_num() @@ -521,9 +522,9 @@ def is_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_worker() """ return self._role_maker._is_worker() @@ -542,9 +543,9 @@ def worker_endpoints(self, to_string=False): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.worker_endpoints() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.worker_endpoints() """ if to_string: @@ -563,9 +564,9 @@ def server_num(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_num() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_num() """ return len(self._role_maker._get_pserver_endpoints()) @@ -580,9 +581,9 @@ def server_index(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_index() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_index() """ return self._role_maker._server_index() @@ -598,9 +599,9 @@ def server_endpoints(self, to_string=False): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.server_endpoints() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.server_endpoints() """ @@ -621,9 +622,9 @@ def is_server(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.is_server() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.is_server() """ return self._role_maker._is_server() @@ -639,9 +640,9 @@ def barrier_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - fleet.barrier_worker() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> fleet.barrier_worker() """ self._role_maker._barrier("worker") @@ -659,13 +660,13 @@ def init_worker(self, scopes=None): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_worker() + >>> fleet.init_worker() """ self._runtime_handle._init_worker(scopes) @@ -704,13 +705,13 @@ def init_server(self, *args, **kwargs): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ self._runtime_handle._init_server(*args, **kwargs) @@ -729,13 +730,13 @@ def load_model(self, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_model("path", mode=0) + >>> fleet.load_model("path", mode=0) """ self._runtime_handle._load_persistables(path, mode) @@ -754,13 +755,13 @@ def load_one_table(self, table_id, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_one_table(0, "path", mode=0) + >>> fleet.load_one_table(0, "path", mode=0) """ self._runtime_handle._load_one_table(table_id, path, mode) @@ -779,13 +780,13 @@ def load_inference_model(self, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.load_inference_model("path", mode=1) + >>> fleet.load_inference_model("path", mode=1) """ self._runtime_handle._load_inference_model(path, mode) @@ -803,14 +804,14 @@ def run_server(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - if fleet.is_server(): - fleet.init_server() + >>> if fleet.is_server(): + ... fleet.init_server() """ self._runtime_handle._run_server() @@ -828,13 +829,13 @@ def stop_worker(self): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ self._runtime_handle._stop_worker() @@ -908,13 +909,13 @@ def save_inference_model( .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.init_server() + >>> fleet.init_server() """ @@ -958,17 +959,17 @@ def save_persistables(self, executor, dirname, main_program=None, mode=0): .. code-block:: text - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet - fleet.init() + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - exe = paddle.static.Executor(paddle.CPUPlace()) - fleet.save_persistables(exe, "dirname", paddle.static.default_main_program()) + >>> exe = paddle.static.Executor(paddle.CPUPlace()) + >>> fleet.save_persistables(exe, "dirname", paddle.static.default_main_program()) """ self._runtime_handle._save_persistables( @@ -1008,13 +1009,13 @@ def save_one_table(self, table_id, path, mode): .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.save_one_table(0, "path", mode=0) + >>> fleet.save_one_table(0, "path", mode=0) """ self._runtime_handle._save_one_table(table_id, path, mode) @@ -1035,16 +1036,16 @@ def save_dense_params( .. code-block:: python - import paddle.distributed.fleet as fleet - fleet.init() - import paddle - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) + >>> import paddle.distributed.fleet as fleet + >>> fleet.init() + >>> import paddle + >>> place = paddle.CPUPlace() + >>> exe = paddle.static.Executor(place) - # build net - # fleet.distributed_optimizer(...) + >>> # build net + >>> # fleet.distributed_optimizer(...) - fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program()) + >>> fleet.save_dense_params(exe, "path", scope=paddle.static.global_scope(), program=paddle.static.default_main_program()) """ self._runtime_handle._save_dense_params( @@ -1078,12 +1079,13 @@ def distributed_optimizer(self, optimizer, strategy=None): .. code-block:: python - import paddle - import paddle.distributed.fleet as fleet - fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> import paddle + >>> import paddle.distributed.fleet as fleet + >>> fleet.init(is_collective=True) + >>> linear = paddle.nn.Linear(10, 10) + >>> strategy = fleet.DistributedStrategy() + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) """ self.user_defined_optimizer = optimizer @@ -1141,46 +1143,46 @@ def amp_init( Examples: .. code-block:: python - import paddle - import paddle.nn.functional as F - paddle.enable_static() - - def run_example_code(): - place = paddle.CUDAPlace(0) - exe = paddle.static.Executor(place) - data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') - conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) - # 1) Use fp16_guard to control the range of fp16 kernels used. - with paddle.static.amp.fp16_guard(): - bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") - pool = F.max_pool2d(bn, kernel_size=2, stride=2) - hidden = paddle.static.nn.fc(pool, size=10) - loss = paddle.mean(hidden) - # 2) Create the optimizer and set `multi_precision` to True. - # Setting `multi_precision` to True can avoid the poor accuracy - # or the slow convergence in a way. - optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) - # 3) These ops in `custom_black_list` will keep in the float32 computation type. - amp_list = paddle.static.amp.CustomOpLists( - custom_black_list=['pool2d']) - # 4) The entry of Paddle AMP. - # Enable pure fp16 training by setting `use_pure_fp16` to True. - optimizer = paddle.static.amp.decorate( - optimizer, - amp_list, - init_loss_scaling=128.0, - use_dynamic_loss_scaling=True, - use_pure_fp16=True) - # If you don't use the default_startup_program(), you sholud pass - # your defined `startup_program` into `minimize`. - optimizer.minimize(loss) - exe.run(paddle.static.default_startup_program()) - # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). - # If you want to perform the testing process, you should pass `test_program` into `amp_init`. - optimizer.amp_init(place, scope=paddle.static.global_scope()) - - if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: - run_example_code() + >>> import paddle + >>> import paddle.nn.functional as F + >>> paddle.enable_static() + + >>> def run_example_code(): + ... place = paddle.CUDAPlace(0) + ... exe = paddle.static.Executor(place) + ... data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32') + ... conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3) + ... # 1) Use fp16_guard to control the range of fp16 kernels used. + ... with paddle.static.amp.fp16_guard(): + ... bn = paddle.static.nn.batch_norm(input=conv2d, act="relu") + ... pool = F.max_pool2d(bn, kernel_size=2, stride=2) + ... hidden = paddle.static.nn.fc(pool, size=10) + ... loss = paddle.mean(hidden) + ... # 2) Create the optimizer and set `multi_precision` to True. + ... # Setting `multi_precision` to True can avoid the poor accuracy + ... # or the slow convergence in a way. + ... optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True) + ... # 3) These ops in `custom_black_list` will keep in the float32 computation type. + ... amp_list = paddle.static.amp.CustomOpLists( + ... custom_black_list=['pool2d']) + ... # 4) The entry of Paddle AMP. + ... # Enable pure fp16 training by setting `use_pure_fp16` to True. + ... optimizer = paddle.static.amp.decorate( + ... optimizer, + ... amp_list, + ... init_loss_scaling=128.0, + ... use_dynamic_loss_scaling=True, + ... use_pure_fp16=True) + ... # If you don't use the default_startup_program(), you sholud pass + ... # your defined `startup_program` into `minimize`. + ... optimizer.minimize(loss) + ... exe.run(paddle.static.default_startup_program()) + ... # 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`). + ... # If you want to perform the testing process, you should pass `test_program` into `amp_init`. + ... optimizer.amp_init(place, scope=paddle.static.global_scope()) + + >>> if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: + ... run_example_code() """ amp_optimizer = self._get_amp_optimizer() return amp_optimizer.amp_init(place, scope, test_program, use_fp16_test) @@ -1273,28 +1275,29 @@ def minimize( .. code-block:: python - import paddle - paddle.enable_static() - import paddle.distributed.fleet as fleet - import paddle.nn.functional as F - - hid_dim = 10 - label_dim = 2 - input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32') - input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64') - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh') - fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh') - prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax') - cost = F.cross_entropy(input=prediction, label=input_y) - avg_cost = paddle.mean(x=cost) - - fleet.init(is_collective=True) - strategy = fleet.DistributedStrategy() - optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX + >>> import paddle + >>> paddle.enable_static() + >>> import paddle.distributed.fleet as fleet + >>> import paddle.nn.functional as F + + >>> hid_dim = 10 + >>> label_dim = 2 + >>> input_x = paddle.static.data(name='x', shape=[None, 13], dtype='float32') + >>> input_y = paddle.static.data(name='y', shape=[None, 1], dtype='int64') + >>> fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh') + >>> fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh') + >>> prediction = paddle.static.nn.fc(x=[fc_2], size=label_dim, activation='softmax') + >>> cost = F.cross_entropy(input=prediction, label=input_y) + >>> avg_cost = paddle.mean(x=cost) + + >>> fleet.init(is_collective=True) + >>> strategy = fleet.DistributedStrategy() + >>> linear = paddle.nn.Linear(10, 10) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=linear.parameters()) + >>> optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + >>> optimizer.minimize(avg_cost) + + >>> # for more examples, please reference https://github.com/PaddlePaddle/PaddleFleetX """ if not isinstance(loss, list): diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py index 7b59a6d5946403..d7febc350a5b52 100644 --- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py +++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import paddle from paddle.autograd import PyLayer from paddle.base import core @@ -20,6 +22,7 @@ from ....communication.reduce import ReduceOp, _get_reduce_op from ...base import topology as tp +from ...utils.log_util import logger from . import mp_ops from .random import get_rng_state_tracker @@ -177,6 +180,9 @@ def forward(self, x): return output +_raise_cuda_env_unset_warning = True + + class InnerOverlapLinear(paddle.autograd.PyLayer): @staticmethod def forward( @@ -216,8 +222,17 @@ def backward(ctx, dy): task = ctx.model_parallel_group.process_group.all_reduce( dx, op_type, sync_op=False ) - # TODO(GhostScreaming): remove it in future. - tmp = paddle.ones([512]) + # Using small operation to preempt GPU SMs for all_reduce to achieve overlap. + if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1: + global _raise_cuda_env_unset_warning + if _raise_cuda_env_unset_warning: + logger.warning( + "You set mp_async_allreduce=True, but you forget to set environment " + "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance " + "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance." + ) + _raise_cuda_env_unset_warning = False + tmp = paddle.ones([512]) if ctx.mp_fused_linear_param_grad_add: if not is_fused_linear_param_grad_add_supported(): @@ -263,7 +278,7 @@ def backward(ctx, dy): weight.main_grad, bias.main_grad, ) = paddle._C_ops.fused_linear_param_grad_add( - input, + x, dy, weight.main_grad, bias.main_grad, @@ -293,9 +308,10 @@ def backward(ctx, dy): task.wait() return dx, dw, dbias else: + dy = dy.reshape([-1, dy.shape[-1]]) dw = paddle.matmul( x.reshape([-1, x.shape[-1]]), - dy.reshape([-1, dy.shape[-1]]), + dy, transpose_x=True, ) if bias is None: diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 9625e2481d4002..75be5f621d4124 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import paddle from paddle.framework import core @@ -99,8 +98,10 @@ def _init_communicator( other_endpoints.remove(current_endpoint) if rank == 0 and wait_port: - use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0") - if use_new_comm not in [1, "1", "True", "true"]: + use_new_comm = paddle.get_flags( + "FLAGS_dynamic_static_unified_comm" + )["FLAGS_dynamic_static_unified_comm"] + if not use_new_comm: wait_server_ready(other_endpoints) def _add_sync_by_allreduce(block): diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 1ee99b10854b9f..ab8ec3a67b145f 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -14,6 +14,7 @@ import os +import paddle from paddle.base import core from paddle.incubate.optimizer import PipelineOptimizer from paddle.static import ( @@ -714,8 +715,10 @@ def minimize_impl( self._recreate_not_persist_param_as_var() self._dump_program_for_debug() - use_new_comm = os.getenv("FLAGS_dynamic_static_unified_comm", "0") - if use_new_comm not in ["1", "True", "true"]: + use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[ + "FLAGS_dynamic_static_unified_comm" + ] + if not use_new_comm: self._wait() return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py index c4e0d54b99a12d..44b799b00c91d7 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py @@ -382,7 +382,12 @@ def scale(grad): if hasattr(param, "main_grad"): param.main_grad.scale_(self._world_size_scaling) else: - grad.scale_(self._world_size_scaling) + if grad is not None and grad._is_initialized(): + grad.scale_(self._world_size_scaling) + else: + assert param.grad is not None + assert param.grad._is_initialized() + param.grad.scale_(self._world_size_scaling) return scale diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py index e541e76e634fc6..ff9ff2ee2a9c3e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py @@ -48,8 +48,10 @@ def _prepare_for_model(self): logger.info("mp's parameters is ready") def _pre_forward(self, *inputs, **kwargs): - mp_configs = self._strategy.hybrid_configs["mp_configs"] - need_broadcast_data = mp_configs.need_broadcast_data + need_broadcast_data = True + if self._strategy is not None: + mp_configs = self._strategy.hybrid_configs["mp_configs"] + need_broadcast_data = mp_configs.need_broadcast_data if need_broadcast_data: logger.debug("mp start broadcast input data") return broadcast_input_data(self._hcg, *inputs, **kwargs) diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 080b56ac6478d7..6a8202965d5be2 100644 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -423,87 +423,94 @@ def recompute(function, *args, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.distributed.fleet.utils import recompute - import random - # required: gpu - def get_fc_block(block_idx, input_size, is_last=False): - block_name = "block_" + str(block_idx) - block = paddle.nn.Sequential( - (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)), - (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), - (block_name + "_relu_1", paddle.nn.ReLU()), - (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)), - (block_name + "_relu_2", paddle.nn.ReLU()), - ) - if is_last: - block.add_sublayer( - block_name + "_fc_2", - paddle.nn.Linear( - input_size, 1, bias_attr=False - ) - ) - else: - block.add_sublayer( - block_name + "_fc_2", - paddle.nn.Linear(input_size, input_size, bias_attr=False) - ) - return block - class Naive_fc_net(paddle.nn.Layer): - def __init__(self, input_size=10, - recompute_blocks=[1, 3], - recompute_kwargs={}): - super().__init__() - self.recompute_blocks = recompute_blocks - self.recompute_kwargs = recompute_kwargs - self.runfunc0 = get_fc_block(0, input_size, is_last=False) - self.runfunc1 = get_fc_block(1, input_size, is_last=False) - self.runfunc2 = get_fc_block(2, input_size, is_last=False) - self.runfunc3 = get_fc_block(3, input_size, is_last=False) - self.runfunc4 = get_fc_block(4, input_size, is_last=True) - self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] - def forward(self, inputs): - nums = len(self.total_func) - for i in range(nums): - if i in self.recompute_blocks: - inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True}) - else: - inputs = self.total_func[i](inputs) - return inputs - def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): - gen = paddle.seed(10) - gen.manual_seed(10) - random.seed(10) - if cuda_state: - paddle.set_cuda_rng_state(cuda_state) - batch_size, input_size = 1, 10 - model = Naive_fc_net( - input_size, - recompute_blocks=recompute_block, - recompute_kwargs=recompute_kwargs) - optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) - loss_ = [] - param_ = [] - grad_ = [] - for _ in range(5): - x = paddle.rand(shape=[batch_size, input_size], dtype="float32") - y_pred = model(x) - loss = y_pred.mean() - loss_.append(loss.item()) - loss.backward() - optimizer.step() - param_.append(model.parameters()[9]) - grad_.append(model.parameters()[3]._grad_ivar()) - optimizer.clear_grad() - return loss_, param_, grad_ - cuda_state = paddle.get_cuda_rng_state() - # without recompute - loss_ref, param_ref, grad_ref = run_model( - cuda_state, recompute_block=[] - ) - loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) - print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) - # The result of the recompute_loss should be the same as the normal_loss. + >>> # doctest: +REQUIRES(env:DISTRIBUTED, env:GPU) + >>> import paddle + >>> from paddle.distributed.fleet.utils import recompute + >>> import random + >>> paddle.seed(2023) + >>> def get_fc_block(block_idx, input_size, is_last=False): + ... block_name = "block_" + str(block_idx) + ... block = paddle.nn.Sequential( + ... (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + ... (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + ... (block_name + "_relu_1", paddle.nn.ReLU()), + ... (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + ... (block_name + "_relu_2", paddle.nn.ReLU()), + ... ) + ... if is_last: + ... block.add_sublayer( + ... block_name + "_fc_2", + ... paddle.nn.Linear( + ... input_size, 1, bias_attr=False + ... ) + ... ) + ... else: + ... block.add_sublayer( + ... block_name + "_fc_2", + ... paddle.nn.Linear(input_size, input_size, bias_attr=False) + ... ) + ... return block + + >>> class Naive_fc_net(paddle.nn.Layer): + ... def __init__(self, input_size=10, + ... recompute_blocks=[1, 3], + ... recompute_kwargs={}): + ... super().__init__() + ... self.recompute_blocks = recompute_blocks + ... self.recompute_kwargs = recompute_kwargs + ... self.runfunc0 = get_fc_block(0, input_size, is_last=False) + ... self.runfunc1 = get_fc_block(1, input_size, is_last=False) + ... self.runfunc2 = get_fc_block(2, input_size, is_last=False) + ... self.runfunc3 = get_fc_block(3, input_size, is_last=False) + ... self.runfunc4 = get_fc_block(4, input_size, is_last=True) + ... self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] + ... def forward(self, inputs): + ... nums = len(self.total_func) + ... for i in range(nums): + ... if i in self.recompute_blocks: + ... inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True}) + ... else: + ... inputs = self.total_func[i](inputs) + ... return inputs + + >>> def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): + ... gen = paddle.seed(10) + ... gen.manual_seed(10) + ... random.seed(10) + ... if cuda_state: + ... paddle.set_cuda_rng_state(cuda_state) + ... batch_size, input_size = 1, 10 + ... model = Naive_fc_net( + ... input_size, + ... recompute_blocks=recompute_block, + ... recompute_kwargs=recompute_kwargs) + ... optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + ... loss_ = [] + ... param_ = [] + ... grad_ = [] + ... for _ in range(5): + ... x = paddle.rand(shape=[batch_size, input_size], dtype="float32") + ... y_pred = model(x) + ... loss = y_pred.mean() + ... loss_.append(loss.item()) + ... loss.backward() + ... optimizer.step() + ... param_.append(model.parameters()[9]) + ... grad_.append(model.parameters()[3]._grad_ivar()) + ... optimizer.clear_grad() + ... return loss_, param_, grad_ + + >>> cuda_state = paddle.get_cuda_rng_state() + >>> # without recompute + >>> loss_ref, param_ref, grad_ref = run_model( + ... cuda_state, recompute_block=[] + ... ) + + >>> loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) + >>> print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) + >>> # The result of the recompute_loss should be the same as the normal_loss. + normal_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0], recompute_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0] + """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) @@ -544,11 +551,14 @@ def recompute_sequential(ctx, functions, *args, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.incubate.distributed.fleet import recompute_sequential - input = paddle.ones(shape=[8, 10]) - model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 2)) - output = recompute_sequential({'segments' : 1}, model, input) + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> from paddle.incubate.distributed.fleet import recompute_sequential + >>> input = paddle.ones(shape=[8, 10]) + >>> model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 2)) + >>> output = recompute_sequential({'segments' : 1}, model, input) + """ segments = ctx.get('segments', 1) preserve_rng_state = ctx.get('preserve_rng_state', True) diff --git a/python/paddle/distributed/fleet/scaler.py b/python/paddle/distributed/fleet/scaler.py index bf0d7363b05251..463674c9587413 100755 --- a/python/paddle/distributed/fleet/scaler.py +++ b/python/paddle/distributed/fleet/scaler.py @@ -29,23 +29,32 @@ def distributed_scaler(scaler): def unscale_method(self, optimizer): if not self._enable: return + + param_grads = [] + param_grads_fp16 = [] + param_grads_fp32 = [] if getattr(optimizer, '_param_groups', None) and isinstance( optimizer._param_groups[0], dict ): - param_grads = [] - param_grads_fp16 = [] - param_grads_fp32 = [] for group in optimizer._param_groups: for param in group['params']: - if param._grad_ivar() is not None: - param_grads.append(param._grad_ivar()) - if ( - param._grad_ivar().dtype - == core.VarDesc.VarType.FP16 - ): - param_grads_fp16.append(param._grad_ivar()) + tgt_grad = None + if ( + hasattr(param, "main_grad") + and param.main_grad is not None + ): + tgt_grad = param.main_grad + elif param.grad is not None: + tgt_grad = param.grad + if tgt_grad is not None: + param_grads.append(tgt_grad) + if tgt_grad.dtype in [ + core.VarDesc.VarType.FP16, + paddle.float16, + ]: + param_grads_fp16.append(tgt_grad) else: - param_grads_fp32.append(param._grad_ivar()) + param_grads_fp32.append(tgt_grad) else: strategy = fleet.fleet._user_defined_strategy sharding_stage_1_overlap = strategy.hybrid_configs[ @@ -67,18 +76,23 @@ def unscale_method(self, optimizer): parameters = optimizer._local_parameter_list else: parameters = optimizer._parameter_list - param_grads_fp16 = [ - param._grad_ivar() - for param in parameters - if (param._grad_ivar() is not None) - and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16) - ] - param_grads_fp32 = [ - param._grad_ivar() - for param in parameters - if (param._grad_ivar() is not None) - and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32) - ] + + for param in parameters: + tgt_grad = None + if hasattr(param, "main_grad") and param.main_grad is not None: + tgt_grad = param.main_grad + elif param.grad is not None: + tgt_grad = param.grad + if tgt_grad is not None: + param_grads.append(tgt_grad) + if tgt_grad.dtype in [ + core.VarDesc.VarType.FP16, + paddle.float16, + ]: + param_grads_fp16.append(tgt_grad) + else: + param_grads_fp32.append(tgt_grad) + temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) self._found_inf = self._temp_found_inf_value_false diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 3b2472b50b46d4..3a751b5d0c3c89 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -51,87 +51,94 @@ def recompute(function, *args, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.distributed.fleet.utils import recompute - import random - # required: gpu - def get_fc_block(block_idx, input_size, is_last=False): - block_name = "block_" + str(block_idx) - block = paddle.nn.Sequential( - (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)), - (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), - (block_name + "_relu_1", paddle.nn.ReLU()), - (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)), - (block_name + "_relu_2", paddle.nn.ReLU()), - ) - if is_last: - block.add_sublayer( - block_name + "_fc_2", - paddle.nn.Linear( - input_size, 1, bias_attr=False - ) - ) - else: - block.add_sublayer( - block_name + "_fc_2", - paddle.nn.Linear(input_size, input_size, bias_attr=False) - ) - return block - class Naive_fc_net(paddle.nn.Layer): - def __init__(self, input_size=10, - recompute_blocks=[1, 3], - recompute_kwargs={}): - super().__init__() - self.recompute_blocks = recompute_blocks - self.recompute_kwargs = recompute_kwargs - self.runfunc0 = get_fc_block(0, input_size, is_last=False) - self.runfunc1 = get_fc_block(1, input_size, is_last=False) - self.runfunc2 = get_fc_block(2, input_size, is_last=False) - self.runfunc3 = get_fc_block(3, input_size, is_last=False) - self.runfunc4 = get_fc_block(4, input_size, is_last=True) - self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] - def forward(self, inputs): - nums = len(self.total_func) - for i in range(nums): - if i in self.recompute_blocks: - inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True}) - else: - inputs = self.total_func[i](inputs) - return inputs - def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): - gen = paddle.seed(10) - gen.manual_seed(10) - random.seed(10) - if cuda_state: - paddle.set_cuda_rng_state(cuda_state) - batch_size, input_size = 1, 10 - model = Naive_fc_net( - input_size, - recompute_blocks=recompute_block, - recompute_kwargs=recompute_kwargs) - optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) - loss_ = [] - param_ = [] - grad_ = [] - for _ in range(5): - x = paddle.rand(shape=[batch_size, input_size], dtype="float32") - y_pred = model(x) - loss = y_pred.mean() - loss_.append(loss.item()) - loss.backward() - optimizer.step() - param_.append(model.parameters()[9]) - grad_.append(model.parameters()[3]._grad_ivar()) - optimizer.clear_grad() - return loss_, param_, grad_ - cuda_state = paddle.get_cuda_rng_state() - # without recompute - loss_ref, param_ref, grad_ref = run_model( - cuda_state, recompute_block=[] - ) - loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) - print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) - # The result of the recompute_loss should be the same as the normal_loss. + >>> # doctest: +REQUIRES(env:DISTRIBUTED, env:GPU) + >>> import paddle + >>> from paddle.distributed.fleet.utils import recompute + >>> import random + >>> paddle.seed(2023) + >>> def get_fc_block(block_idx, input_size, is_last=False): + ... block_name = "block_" + str(block_idx) + ... block = paddle.nn.Sequential( + ... (block_name + "_fc_0", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + ... (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + ... (block_name + "_relu_1", paddle.nn.ReLU()), + ... (block_name + "_fc_1", paddle.nn.Linear(input_size, input_size, bias_attr=False)), + ... (block_name + "_relu_2", paddle.nn.ReLU()), + ... ) + ... if is_last: + ... block.add_sublayer( + ... block_name + "_fc_2", + ... paddle.nn.Linear( + ... input_size, 1, bias_attr=False + ... ) + ... ) + ... else: + ... block.add_sublayer( + ... block_name + "_fc_2", + ... paddle.nn.Linear(input_size, input_size, bias_attr=False) + ... ) + ... return block + + >>> class Naive_fc_net(paddle.nn.Layer): + ... def __init__(self, input_size=10, + ... recompute_blocks=[1, 3], + ... recompute_kwargs={}): + ... super().__init__() + ... self.recompute_blocks = recompute_blocks + ... self.recompute_kwargs = recompute_kwargs + ... self.runfunc0 = get_fc_block(0, input_size, is_last=False) + ... self.runfunc1 = get_fc_block(1, input_size, is_last=False) + ... self.runfunc2 = get_fc_block(2, input_size, is_last=False) + ... self.runfunc3 = get_fc_block(3, input_size, is_last=False) + ... self.runfunc4 = get_fc_block(4, input_size, is_last=True) + ... self.total_func = [self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4] + ... def forward(self, inputs): + ... nums = len(self.total_func) + ... for i in range(nums): + ... if i in self.recompute_blocks: + ... inputs = recompute(self.total_func[i], inputs, **{"preserve_rng_state": True}) + ... else: + ... inputs = self.total_func[i](inputs) + ... return inputs + + >>> def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): + ... gen = paddle.seed(10) + ... gen.manual_seed(10) + ... random.seed(10) + ... if cuda_state: + ... paddle.set_cuda_rng_state(cuda_state) + ... batch_size, input_size = 1, 10 + ... model = Naive_fc_net( + ... input_size, + ... recompute_blocks=recompute_block, + ... recompute_kwargs=recompute_kwargs) + ... optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + ... loss_ = [] + ... param_ = [] + ... grad_ = [] + ... for _ in range(5): + ... x = paddle.rand(shape=[batch_size, input_size], dtype="float32") + ... y_pred = model(x) + ... loss = y_pred.mean() + ... loss_.append(loss.item()) + ... loss.backward() + ... optimizer.step() + ... param_.append(model.parameters()[9]) + ... grad_.append(model.parameters()[3]._grad_ivar()) + ... optimizer.clear_grad() + ... return loss_, param_, grad_ + + >>> cuda_state = paddle.get_cuda_rng_state() + >>> # without recompute + >>> loss_ref, param_ref, grad_ref = run_model( + ... cuda_state, recompute_block=[] + ... ) + + >>> loss, param, grad = run_model(cuda_state, recompute_block=[1, 2]) + >>> print("normal_loss: {}, recompute_loss: {}".format(loss_ref, loss)) + >>> # The result of the recompute_loss should be the same as the normal_loss. + normal_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0], recompute_loss: [0.0018744759727269411, 0.0, 0.035971127450466156, 0.0, 0.0] + """ return fleet.recompute.recompute(function, *args, **kwargs) diff --git a/python/paddle/distributed/fleet/utils/fs.py b/python/paddle/distributed/fleet/utils/fs.py index 11617981d9d4b0..743ceac3e296cc 100644 --- a/python/paddle/distributed/fleet/utils/fs.py +++ b/python/paddle/distributed/fleet/utils/fs.py @@ -117,10 +117,12 @@ class LocalFS(FS): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> subdirs, files = client.ls_dir("./") - client = LocalFS() - subdirs, files = client.ls_dir("./") """ def ls_dir(self, fs_path): @@ -137,10 +139,12 @@ def ls_dir(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> subdirs, files = client.ls_dir("./") - client = LocalFS() - subdirs, files = client.ls_dir("./") """ if not self.is_exist(fs_path): return [], [] @@ -165,11 +169,13 @@ def mkdirs(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.mkdirs("test_mkdirs") + >>> client.delete("test_mkdirs") - client = LocalFS() - client.mkdirs("test_mkdirs") - client.delete("test_mkdirs") """ assert not os.path.isfile(fs_path), f"{fs_path} is already a file" os.makedirs(fs_path, exist_ok=True) @@ -185,15 +191,20 @@ def rename(self, fs_src_path, fs_dst_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.touch("test_rename_src") + >>> print(client.is_exist("test_rename_src")) + True + >>> client.rename("test_rename_src", "test_rename_dst") + >>> print(client.is_exist("test_rename_src")) + False + >>> print(client.is_exist("test_rename_dst")) + True + >>> client.delete("test_rename_dst") - client = LocalFS() - client.touch("test_rename_src") - print(client.is_exists("test_rename_src")) # True - client.rename("test_rename_src", "test_rename_dst") - print(client.is_exists("test_rename_src")) # False - print(client.is_exists("test_rename_dst")) # True - client.delete("test_rename_dst") """ os.rename(fs_src_path, fs_dst_path) @@ -213,11 +224,13 @@ def delete(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.mkdirs("test_localFS_mkdirs") + >>> client.delete("test_localFS_mkdirs") - client = LocalFS() - client.mkdirs("test_localFS_mkdirs") - client.delete("test_localFS_mkdirs") """ if not self.is_exist(fs_path): return @@ -243,12 +256,15 @@ def is_file(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.touch("test_is_file") + >>> print(client.is_file("test_is_file")) + True + >>> client.delete("test_is_file") - client = LocalFS() - client.touch("test_is_file") - print(client.is_file("test_is_file")) # True - client.delete("test_is_file") """ return os.path.isfile(fs_path) @@ -265,12 +281,15 @@ def is_dir(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.mkdirs("test_is_dir") + >>> print(client.is_dir("test_is_dir")) + True + >>> client.delete("test_is_dir") - client = LocalFS() - client.mkdirs("test_is_dir") - print(client.is_dir("test_is_file")) # True - client.delete("test_is_dir") """ return os.path.isdir(fs_path) @@ -288,10 +307,12 @@ def is_exist(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> local_fs = LocalFS() + >>> ret = local_fs.is_exist("test_is_exist") - client = LocalFS() - ret = local_fs.is_exist("test_is_exist") """ return os.path.exists(fs_path) @@ -307,11 +328,13 @@ def touch(self, fs_path, exist_ok=True): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.touch("test_touch") + >>> client.delete("test_touch") - client = LocalFS() - client.touch("test_touch") - client.delete("test_touch") """ if self.is_exist(fs_path): if exist_ok: @@ -332,12 +355,14 @@ def mv(self, src_path, dst_path, overwrite=False, test_exists=False): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> client.touch("test_mv_src") + >>> client.mv("test_mv_src", "test_mv_dst") + >>> client.delete("test_mv_dst") - client = LocalFS() - client.touch("test_mv_src") - client.mv("test_mv_src", "test_mv_dst") - client.delete("test_mv_dst") """ if not self.is_exist(src_path): raise FSFileNotExistsError @@ -363,10 +388,12 @@ def list_dirs(self, fs_path): Examples: .. code-block:: python - from paddle.distributed.fleet.utils import LocalFS + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import LocalFS + + >>> client = LocalFS() + >>> subdirs = client.list_dirs("./") - client = LocalFS() - subdirs = client.list_dirs("./") """ if not self.is_exist(fs_path): return [] @@ -428,18 +455,21 @@ class HDFSClient(FS): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" - from paddle.distributed.fleet.utils import HDFSClient - hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.ls_dir("hdfs:/test_hdfs_client") + ([], []) - client = HDFSClient(hadoop_home, configs) - client.ls_dir("hdfs:/test_hdfs_client") """ def __init__( @@ -496,18 +526,20 @@ def list_dirs(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> subdirs = client.list_dirs("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - subdirs = client.list_dirs("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return [] @@ -529,18 +561,20 @@ def ls_dir(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> subdirs, files = client.ls_dir("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - subdirs, files = client.ls_dir("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return [], [] @@ -590,18 +624,20 @@ def is_dir(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> ret = client.is_file("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - ret = client.is_file("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return False @@ -634,18 +670,20 @@ def is_file(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> ret = client.is_file("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - ret = client.is_file("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return False @@ -666,18 +704,20 @@ def is_exist(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DITSTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> ret = client.is_exist("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - ret = client.is_exist("hdfs:/test_hdfs_client") """ cmd = f"test -e {fs_path} " ret, out = self._run_cmd(cmd, redirect_stderr=True, retry_times=1) @@ -718,18 +758,20 @@ def upload(self, local_path, fs_path, multi_processes=5, overwrite=False): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on external file') + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.upload("test_hdfs_client", "hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - client.upload("test_hdfs_client", "hdfs:/test_hdfs_client") """ def __subprocess_upload(hdfs_path_single, datas): @@ -808,18 +850,20 @@ def download(self, fs_path, local_path, multi_processes=5, overwrite=False): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on external file') + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.download("hdfs:/test_hdfs_client", "./") - client = HDFSClient(hadoop_home, configs) - client.download("hdfs:/test_hdfs_client", "./") """ def __subprocess_download(local_path, datas): @@ -877,18 +921,20 @@ def mkdirs(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on external file') + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.mkdirs("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - client.mkdirs("hdfs:/test_hdfs_client") """ if self.is_exist(fs_path): return @@ -923,18 +969,20 @@ def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on external file') + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2") - client = HDFSClient(hadoop_home, configs) - client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2") """ if overwrite and self.is_exist(fs_dst_path): self.delete(fs_dst_path) @@ -983,18 +1031,20 @@ def delete(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.delete("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - client.delete("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return @@ -1016,18 +1066,20 @@ def touch(self, fs_path, exist_ok=True): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on external file') + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.touch("hdfs:/test_hdfs_client") - client = HDFSClient(hadoop_home, configs) - client.touch("hdfs:/test_hdfs_client") """ if self.is_exist(fs_path): if exist_ok: @@ -1058,18 +1110,21 @@ def cat(self, fs_path=None): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> from paddle.distributed.fleet.utils import HDFSClient - from paddle.distributed.fleet.utils import HDFSClient + >>> hadoop_home = "/home/client/hadoop-client/hadoop/" + >>> configs = { + ... "fs.default.name": "hdfs://xxx.hadoop.com:54310", + ... "hadoop.job.ugi": "hello,hello123" + ... } - hadoop_home = "/home/client/hadoop-client/hadoop/" - configs = { - "fs.default.name": "hdfs://xxx.hadoop.com:54310", - "hadoop.job.ugi": "hello,hello123" - } + >>> client = HDFSClient(hadoop_home, configs) + >>> client.cat("hdfs:/test_hdfs_client") + '' - client = HDFSClient(hadoop_home, configs) - client.cat("hdfs:/test_hdfs_client") """ if self.is_file(fs_path): output = self._try_cat(fs_path) @@ -1151,12 +1206,15 @@ class AFSClient(FS): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient + + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.ls_dir("hdfs:/test_hdfs_client") - from paddle.distributed.fleet.utils import AFSClient - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.ls_dir("hdfs:/test_hdfs_client") """ def __init__(self, time_out=5 * 60 * 1000, sleep_inter=1000): # ms # ms @@ -1178,13 +1236,15 @@ def list_dirs(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> subdirs = client.list_dirs("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - subdirs = client.list_dirs("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return [] @@ -1205,13 +1265,15 @@ def ls_dir(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> subdirs, files = client.ls_dir("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - subdirs, files = client.ls_dir("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return [], [] @@ -1235,13 +1297,15 @@ def is_dir(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> ret = client.is_dir("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - ret = client.is_file("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return False @@ -1267,13 +1331,15 @@ def is_file(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> ret = client.is_file("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - ret = client.is_file("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return False @@ -1293,13 +1359,15 @@ def is_exist(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> ret = client.is_exist("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - ret = client.is_exist("hdfs:/test_hdfs_client") """ return self._fs.exist(fs_path) @@ -1335,13 +1403,15 @@ def upload(self, local_path, fs_path, multi_processes=1, overwrite=False): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.upload("test_hdfs_client", "hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.upload("test_hdfs_client", "hdfs:/test_hdfs_client") """ local = LocalFS() @@ -1362,13 +1432,15 @@ def download(self, fs_path, local_path, multi_processes=1, overwrite=False): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.download("hdfs:/test_hdfs_client", "./") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.download("hdfs:/test_hdfs_client", "./") """ def __subprocess_download(local_path, datas): @@ -1411,13 +1483,15 @@ def mkdirs(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.mkdirs("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.mkdirs("hdfs:/test_hdfs_client") """ if self.is_exist(fs_path): return @@ -1435,13 +1509,15 @@ def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2") """ if overwrite and self.is_exist(fs_dst_path): self.delete(fs_dst_path) @@ -1464,15 +1540,16 @@ def delete(self, fs_path): Examples: - .. code-block:: text + .. code-block:: python + - from paddle.distributed.fleet.utils import HDFSClient + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.delete("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.delete("hdfs:/test_hdfs_client") """ if not self.is_exist(fs_path): return @@ -1489,13 +1566,15 @@ def touch(self, fs_path, exist_ok=True): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.touch("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.touch("hdfs:/test_hdfs_client") """ if self.is_exist(fs_path): if exist_ok: @@ -1519,13 +1598,15 @@ def cat(self, fs_path=None): Examples: - .. code-block:: text + .. code-block:: python + + >>> # doctest: +SKIP('depend on WITH_PSLIB') + >>> from paddle.distributed.fleet.utils.fs import AFSClient - from paddle.distributed.fleet.utils import AFSClient + >>> client = AFSClient() + >>> client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") + >>> client.cat("hdfs:/test_hdfs_client") - client = AFSClient() - client.init("hdfs://xxx.hadoop.com:54310", "hello", "hello123", "./fs_conf") - client.cat("hdfs:/test_hdfs_client") """ if self.is_file(fs_path): return self._fs.cat(fs_path) diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py index 9c44fc49fff672..9170754bb78ff8 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_inference.py @@ -42,145 +42,123 @@ class HybridParallelInferenceHelper: Write Paradigm: - .. code-block:: bash - :name: bash-example1 - - # while op pattern - with paddle.base.device_guard(f'{device}:all'): - # init global cond - max_len = paddle.full(shape=[1], dtype="int64", fill_value=10) - step_idx = paddle.full(shape=[1], dtype="int64", fill_value=0) - cond_int = paddle.full(shape=[1], dtype="int64", fill_value=0, name="cond_int") - cond = layers.cast(step_idx < max_len, dtype="bool") - while_op = layers.While(cond, is_test=True) - - # init global lod_tensor_array for generation task - arr = paddle.tensor.array_write(data, step_idx) - - with while_op.block(): - with paddle.base.device_guard(f'{device}:all'): - # read data from global lod_tensor_array - element_in_arr = paddle.tensor.array_read(array=arr, i=step_idx) - # write placehold data to global lod_tensor_array, - # it need for send_v2 of lod_tensor_array - paddle.increment(x=step_idx, value=1.0) - paddle.tensor.array_write(element_in_arr, i=step_idx, array=arr) - - with paddle.base.device_guard(f'{device}:0'): - ... some code - - with paddle.base.device_guard(f'{device}:1'): - ... some code - - with paddle.base.device_guard(f'{device}:{num_pp-1}'): - # generate some data in while block and write to global lod_tensor_array - # that they are read in next while step. - # we will using send_v2 to send global lod_tensor_array to other pipeline and sync - paddle.tensor.array_write(other_var, i=step_idx, array=arr) - - # update cond and assign to cond_int, we will sync cond_int - layers.assign(layers.cast(cond, dtype="int32"), cond_int) - - with paddle.base.device_guard(f'{model._device}:all'): - # the code below must at end of while block and exists in device:all - layers.assign(layers.cast(cond_int, dtype='bool'), cond) - - with paddle.base.device_guard(f'{model._device}:all'): - # use a empty lod_tensor_array to clear lod_tensor_array - layers.assign(layers.create_array(data.dtype), arr) - + .. code-block:: text + :name: text-example1 + + >>> # doctest: +REQUIRES(env:DISTRIBUTED, env:GPU) + >>> import paddle + >>> # while op pattern + >>> with paddle.base.device_guard(f'{device}:all'): + ... # init global cond + ... max_len = paddle.full(shape=[1], dtype="int64", fill_value=10) + ... step_idx = paddle.full(shape=[1], dtype="int64", fill_value=0) + ... cond_int = paddle.full(shape=[1], dtype="int64", fill_value=0, name="cond_int") + ... cond = layers.cast(step_idx < max_len, dtype="bool") + ... while_op = layers.While(cond, is_test=True) + + ... # init global lod_tensor_array for generation task + ... arr = paddle.tensor.array_write(data, step_idx) + + >>> with while_op.block(): + ... with paddle.base.device_guard(f'{device}:all'): + ... # read data from global lod_tensor_array + ... element_in_arr = paddle.tensor.array_read(array=arr, i=step_idx) + ... # write placehold data to global lod_tensor_array, + ... # it need for send_v2 of lod_tensor_array + ... paddle.increment(x=step_idx, value=1.0) + ... paddle.tensor.array_write(element_in_arr, i=step_idx, array=arr) + ... with paddle.base.device_guard(f'{device}:0'): + ... pass # some code + ... with paddle.base.device_guard(f'{device}:1'): + ... pass # some code + ... with paddle.base.device_guard(f'{device}:{num_pp-1}'): + ... # generate some data in while block and write to global lod_tensor_array + ... # that they are read in next while step. + ... # we will using send_v2 to send global lod_tensor_array to other pipeline and sync + ... paddle.tensor.array_write(other_var, i=step_idx, array=arr) + ... # update cond and assign to cond_int, we will sync cond_int + ... layers.assign(layers.cast(cond, dtype="int32"), cond_int) + ... with paddle.base.device_guard(f'{model._device}:all'): + ... # the code below must at end of while block and exists in device:all + ... layers.assign(layers.cast(cond_int, dtype='bool'), cond) + + >>> with paddle.base.device_guard(f'{model._device}:all'): + ... # use a empty lod_tensor_array to clear lod_tensor_array + ... layers.assign(layers.create_array(data.dtype), arr) Examples: - .. code-block:: python - :name: code-example1 - - # required: distributed - import os - import numpy as np - import paddle - import paddle.base.layers as layers - import paddle.distributed.fleet as fleet - paddle.enable_static() - - nranks = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) - rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) - dev_id = int(os.getenv("FLAGS_selected_gpus", 0)) - - main_program = paddle.static.Program() - startup_program = paddle.static.Program() - - if nranks > 1: - dist_strategy = fleet.DistributedStrategy() - dist_strategy.without_graph_optimization = True - fleet.init(is_collective=True, strategy=dist_strategy) - - device = "gpu" - - with paddle.static.program_guard(main_program, startup_program): - with paddle.base.device_guard(f'{device}:0'): - X = paddle.static.data(name='X', shape=[None, 2], dtype='float32') - - with paddle.base.device_guard(f'{device}:all'): - max_len = paddle.full( - shape=[1], dtype="int64", fill_value=5, name="n") - step_idx = paddle.full( - shape=[1], dtype="int64", fill_value=0, name="i") - - data = paddle.tensor.array_write(X, step_idx) - - cond_int = paddle.full(shape=[1], dtype="int64", fill_value=0, name="cond_int") - cond = paddle.less_than(x=step_idx, y=max_len) - while_op = layers.While(cond, is_test=True) - - with while_op.block(): - with paddle.base.device_guard(f'{device}:all'): - input = paddle.tensor.array_read(array=data, i=step_idx) - paddle.increment(x=step_idx, value=1.0) - paddle.tensor.array_write(input, i=step_idx, array=data) - - with paddle.base.device_guard(f'{device}:0'): - param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1.0)) - weight1 = paddle.static.create_parameter( - shape=[2, 5], dtype='float32', attr=param_attr, is_bias=False) - hidden1 = paddle.matmul(input, weight1) - - with paddle.base.device_guard(f'{device}:1'): - param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(2.0)) - weight2 = paddle.static.create_parameter( - shape=[5, 2], dtype='float32', attr=param_attr, is_bias=False) - hidden2 = paddle.matmul(hidden1, weight2) - - paddle.tensor.array_write(hidden2, i=step_idx, array=data) - - # update cond and assign to cond_int, we will sync cond_int - paddle.assign(paddle.less_than(x=step_idx, y=max_len), cond) - layers.assign(layers.cast(cond, dtype="int32"), cond_int) - - with paddle.base.device_guard(f'{device}:all'): - # the code below must at end of while block and exists in device:all - layers.assign(layers.cast(cond_int, dtype='bool'), cond) - - with paddle.base.device_guard(f'{device}:all'): - out = layers.create_array(data.dtype) - layers.assign(data, out) - - with paddle.base.device_guard(f'{device}:all'): - # use a empty lod_tensor_array to clear lod_tensor_array - layers.assign(layers.create_array(data.dtype), data) - - helper = fleet.HybridParallelInferenceHelper(startup_program, main_program, micro_batch_size=2, num_pp=2, init_comm=nranks>1) - helper.gen_infer_program(['array_write_0.out'], ['cond_int.tmp_0']) - - exe = paddle.static.Executor(paddle.CUDAPlace(dev_id)) - exe.run(startup_program) - - np.random.seed(2333) - for step in range(5): - init_data = np.random.uniform(low=0.0, high=1.0, size=[2, 2]).astype('float32') - [res] = exe.run(main_program, feed={"X": init_data}, fetch_list=[out]) - print('-------- step', step, ' --------') - print(res) + .. code-block:: python + :name: code-example1 + + >>> # doctest: +REQUIRES(env:DISTRIBUTED, env:GPU) + >>> import os + >>> import numpy as np + >>> import paddle + >>> import paddle.distributed.fleet as fleet + >>> from paddle.distributed.fleet.utils import hybrid_parallel_inference + >>> paddle.enable_static() + >>> nranks = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) + >>> rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) + >>> dev_id = int(os.getenv("FLAGS_selected_gpus", 0)) + >>> main_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + >>> if nranks > 1: + ... dist_strategy = fleet.DistributedStrategy() + ... dist_strategy.without_graph_optimization = True + ... fleet.init(is_collective=True, strategy=dist_strategy) + >>> device = "gpu" + >>> with paddle.static.program_guard(main_program, startup_program): + ... with paddle.base.device_guard(f'{device}:0'): + ... X = paddle.static.data(name='X', shape=[None, 2], dtype='float32') + ... with paddle.base.device_guard(f'{device}:all'): + ... max_len = paddle.full( + ... shape=[1], dtype="int64", fill_value=5, name="n") + ... step_idx = paddle.full( + ... shape=[1], dtype="int64", fill_value=0, name="i") + ... data = paddle.tensor.array_write(X, step_idx) + ... cond_int = paddle.full(shape=[1], dtype="int64", fill_value=0, name="cond_int") + ... cond = paddle.less_than(x=step_idx, y=max_len) + ... while_op = paddle.static.nn.control_flow.While(cond, is_test=True) + ... with while_op.block(): + ... with paddle.base.device_guard(f'{device}:all'): + ... input = paddle.tensor.array_read(array=data, i=step_idx) + ... paddle.increment(x=step_idx, value=1.0) + ... paddle.tensor.array_write(input, i=step_idx, array=data) + ... with paddle.base.device_guard(f'{device}:0'): + ... param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1.0)) + ... weight1 = paddle.static.create_parameter( + ... shape=[2, 5], dtype='float32', attr=param_attr, is_bias=False) + ... hidden1 = paddle.matmul(input, weight1) + ... with paddle.base.device_guard(f'{device}:1'): + ... param_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(2.0)) + ... weight2 = paddle.static.create_parameter( + ... shape=[5, 2], dtype='float32', attr=param_attr, is_bias=False) + ... hidden2 = paddle.matmul(hidden1, weight2) + ... paddle.tensor.array_write(hidden2, i=step_idx, array=data) + ... # update cond and assign to cond_int, we will sync cond_int + ... paddle.assign(paddle.less_than(x=step_idx, y=max_len), cond) + ... paddle.assign(paddle.cast(cond, dtype="int32"), cond_int) + ... with paddle.base.device_guard(f'{device}:all'): + ... # the code below must at end of while block and exists in device:all + ... paddle.assign(paddle.cast(cond_int, dtype='bool'), cond) + ... with paddle.base.device_guard(f'{device}:all'): + ... out = paddle.tensor.create_array(data.dtype) + ... paddle.assign(data, out) + ... with paddle.base.device_guard(f'{device}:all'): + ... # use a empty lod_tensor_array to clear lod_tensor_array + ... paddle.assign(paddle.tensor.create_array(data.dtype), data) + >>> helper = hybrid_parallel_inference.HybridParallelInferenceHelper(startup_program, main_program, micro_batch_size=2, num_pp=2, init_comm=nranks>1) + >>> helper.gen_infer_program(['array_write_0.out'], ['cond_int.tmp_0']) + >>> exe = paddle.static.Executor(paddle.CUDAPlace(dev_id)) + >>> exe.run(startup_program) + >>> np.random.seed(2333) + >>> for step in range(5): + ... init_data = np.random.uniform(low=0.0, high=1.0, size=[2, 2]).astype('float32') + ... [res] = exe.run(main_program, feed={"X": init_data}, fetch_list=[out]) + ... print('-------- step', step, ' --------') + ... print(res) + """ def __init__( diff --git a/python/paddle/distributed/fleet/utils/mix_precision_utils.py b/python/paddle/distributed/fleet/utils/mix_precision_utils.py index f6b04bbfda011e..e779d41b8f3faa 100644 --- a/python/paddle/distributed/fleet/utils/mix_precision_utils.py +++ b/python/paddle/distributed/fleet/utils/mix_precision_utils.py @@ -28,6 +28,7 @@ obtain_optimizer_parameters_list, ) from paddle.framework import core +from paddle.utils import deprecated class MixPrecisionLayer(nn.Layer): @@ -232,6 +233,11 @@ def unscale_method(self, optimizer): self._found_inf = int(is_found_inf) +@deprecated( + since="2.5.0", + update_to="paddle.distributed_scaler", + level=1, +) class MixPrecisionScaler: def __init__(self, scaler): self._inner_scaler = scaler diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index bec8b72fac52cc..8890ab0bd179ae 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -247,50 +247,43 @@ class DataParallel(layers.Layer): .. code-block:: python :name: dp-example - # required: distributed - import paddle - import paddle.nn as nn - import paddle.optimizer as opt - import paddle.distributed as dist - - class LinearNet(nn.Layer): - def __init__(self): - super().__init__() - self._linear1 = nn.Linear(10, 10) - self._linear2 = nn.Linear(10, 1) - - def forward(self, x): - return self._linear2(self._linear1(x)) - - def train(): - # 1. initialize parallel environment - dist.init_parallel_env() - - # 2. create data parallel layer & optimizer - layer = LinearNet() - dp_layer = paddle.DataParallel(layer) - - loss_fn = nn.MSELoss() - adam = opt.Adam( - learning_rate=0.001, parameters=dp_layer.parameters()) - - # 3. run layer - inputs = paddle.randn([10, 10], 'float32') - outputs = dp_layer(inputs) - labels = paddle.randn([10, 1], 'float32') - loss = loss_fn(outputs, labels) - - loss.backward() - - adam.step() - adam.clear_grad() - - if __name__ == '__main__': - # 1. start by ``paddle.distributed.spawn`` (default) - dist.spawn(train, nprocs=2) - # 2. start by ``paddle.distributed.launch`` - # train() - + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> import paddle.nn as nn + >>> import paddle.optimizer as opt + >>> import paddle.distributed as dist + + >>> class LinearNet(nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self._linear1 = nn.Linear(10, 10) + ... self._linear2 = nn.Linear(10, 1) + ... def forward(self, x): + ... return self._linear2(self._linear1(x)) + + >>> def train(): + ... # 1. initialize parallel environment + ... dist.init_parallel_env() + ... # 2. create data parallel layer & optimizer + ... layer = LinearNet() + ... dp_layer = paddle.DataParallel(layer) + ... loss_fn = nn.MSELoss() + ... adam = opt.Adam( + ... learning_rate=0.001, parameters=dp_layer.parameters()) + ... # 3. run layer + ... inputs = paddle.randn([10, 10], 'float32') + ... outputs = dp_layer(inputs) + ... labels = paddle.randn([10, 1], 'float32') + ... loss = loss_fn(outputs, labels) + ... loss.backward() + ... adam.step() + ... adam.clear_grad() + + >>> if __name__ == '__main__': + ... # 1. start by ``paddle.distributed.spawn`` (default) + ... dist.spawn(train, nprocs=2) + ... # 2. start by ``paddle.distributed.launch`` + ... # train() .. note:: ``PyLayer`` is not supported in DataParallel. To solve problems of this kind, @@ -303,58 +296,51 @@ def train(): .. code-block:: python :name: dp-pylayer-example - # required: distributed - import numpy - import paddle - import paddle.distributed as dist - from paddle.autograd import PyLayer - from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - y = paddle.tanh(x) - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - - class SimpleNet(paddle.nn.Layer): - def __init__(self): - super().__init__() - self.linear = paddle.nn.Linear(2, 2) - - def forward(self, inputs): - inputs = cus_tanh.apply(inputs) - return self.linear(inputs) - - if __name__ == '__main__': - dist.init_parallel_env() - - model = SimpleNet() - model = paddle.DataParallel(model) - opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) - - for step in range(10): - x_data = numpy.random.randn(2, 2).astype(numpy.float32) - x = paddle.to_tensor(x_data) - x.stop_gradient = False - - # step 1 : skip gradient synchronization by 'no_sync' - with model.no_sync(): - y_pred = model(x) - loss = y_pred.mean() - loss.backward() - - # step 2 : fuse + allreduce manually before optimization - fused_allreduce_gradients(list(model.parameters()), None) - - opt.step() - opt.clear_grad() + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import numpy + >>> import paddle + >>> import paddle.distributed as dist + >>> from paddle.autograd import PyLayer + >>> from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients + + >>> class cus_tanh(PyLayer): + ... @staticmethod + ... def forward(ctx, x): + ... y = paddle.tanh(x) + ... ctx.save_for_backward(y) + ... return y + ... @staticmethod + ... def backward(ctx, dy): + ... y, = ctx.saved_tensor() + ... grad = dy * (1 - paddle.square(y)) + ... return grad + + >>> class SimpleNet(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.linear = paddle.nn.Linear(2, 2) + ... def forward(self, inputs): + ... inputs = cus_tanh.apply(inputs) + ... return self.linear(inputs) + + >>> if __name__ == '__main__': + ... dist.init_parallel_env() + ... model = SimpleNet() + ... model = paddle.DataParallel(model) + ... opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters()) + ... for step in range(10): + ... x_data = numpy.random.randn(2, 2).astype(numpy.float32) + ... x = paddle.to_tensor(x_data) + ... x.stop_gradient = False + ... # step 1 : skip gradient synchronization by 'no_sync' + ... with model.no_sync(): + ... y_pred = model(x) + ... loss = y_pred.mean() + ... loss.backward() + ... # step 2 : fuse + allreduce manually before optimization + ... fused_allreduce_gradients(list(model.parameters()), None) + ... opt.step() + ... opt.clear_grad() """ @@ -502,32 +488,31 @@ def no_sync(self): Examples: .. code-block:: python - # required: distributed - import paddle - import paddle.nn as nn - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> import paddle.nn as nn + >>> import paddle.distributed as dist - class SimpleNet(nn.Layer): - def __init__(self): - super().__init__() - self._linear = nn.Linear(10, 1) + >>> class SimpleNet(nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self._linear = nn.Linear(10, 1) + ... def forward(self, x): + ... return self._linear(x) - def forward(self, x): - return self._linear(x) + >>> dist.init_parallel_env() + >>> model = SimpleNet() + >>> dp_model = paddle.DataParallel(model) - dist.init_parallel_env() - model = SimpleNet() - dp_model = paddle.DataParallel(model) + >>> inputs_1 = paddle.randn([10, 10], 'float32') + >>> inputs_2 = paddle.ones([10, 10], 'float32') - inputs_1 = paddle.randn([10, 10], 'float32') - inputs_2 = paddle.ones([10, 10], 'float32') + >>> with dp_model.no_sync(): + ... # gradients will not be synchronized + ... dp_model(inputs_1).backward() - with dp_model.no_sync(): - # gradients will not be synchronized - dp_model(inputs_1).backward() - - # synchronization happens here - dp_model(inputs_2).backward() + >>> # synchronization happens here + >>> dp_model(inputs_2).backward() """ tmp_grad_need_sync = self.grad_need_sync @@ -586,16 +571,17 @@ def state_dict( Examples: .. code-block:: python - import paddle - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> import paddle.distributed as dist - dist.init_parallel_env() + >>> dist.init_parallel_env() - emb = paddle.nn.Embedding(10, 10) - emb = paddle.DataParallel(emb) + >>> emb = paddle.nn.Embedding(10, 10) + >>> emb = paddle.DataParallel(emb) - state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy.pdparams") + >>> state_dict = emb.state_dict() + >>> paddle.save(state_dict, "paddle_dy.pdparams") ''' @@ -620,19 +606,20 @@ def set_state_dict(self, state_dict, use_structured_name=True): Examples: .. code-block:: python - import paddle - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> import paddle.distributed as dist - dist.init_parallel_env() + >>> dist.init_parallel_env() - emb = paddle.nn.Embedding(10, 10) - emb = paddle.DataParallel(emb) + >>> emb = paddle.nn.Embedding(10, 10) + >>> emb = paddle.DataParallel(emb) - state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy.pdparams") + >>> state_dict = emb.state_dict() + >>> paddle.save(state_dict, "paddle_dy.pdparams") - para_state_dict = paddle.load("paddle_dy.pdparams") - emb.set_state_dict(para_state_dict) + >>> para_state_dict = paddle.load("paddle_dy.pdparams") + >>> emb.set_state_dict(para_state_dict) ''' @@ -664,32 +651,34 @@ class ParallelEnv: or ``paddle.distributed.spawn`` . Examples: - .. code-block:: python - - import paddle - import paddle.distributed as dist - - def train(): - # 1. initialize parallel environment - dist.init_parallel_env() - - # 2. get current ParallelEnv - parallel_env = dist.ParallelEnv() - print("rank: ", parallel_env.rank) - print("world_size: ", parallel_env.world_size) - - # print result in process 1: - # rank: 1 - # world_size: 2 - # print result in process 2: - # rank: 2 - # world_size: 2 - - if __name__ == '__main__': - # 1. start by ``paddle.distributed.spawn`` (default) - dist.spawn(train, nprocs=2) - # 2. start by ``paddle.distributed.launch`` - # train() + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> import paddle + >>> import paddle.distributed as dist + + >>> def train(): + ... # 1. initialize parallel environment + ... dist.init_parallel_env() + ... # 2. get current ParallelEnv + ... parallel_env = dist.ParallelEnv() + ... print("rank: ", parallel_env.rank) + ... print("world_size: ", parallel_env.world_size) + + >>> if __name__ == '__main__': + ... # 1. start by ``paddle.distributed.spawn`` (default) + ... dist.spawn(train, nprocs=2) + ... # 2. start by ``paddle.distributed.launch`` + ... train() + + # Print result in process 1: + rank: 1 + world_size: 2 + + # Print result in process 2: + rank: 2 + world_size: 2 + """ def __init__(self): @@ -734,14 +723,16 @@ def rank(self): Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` . The default value is 0. Examples: - .. code-block:: python + .. code-block:: python - # execute this command in terminal: export PADDLE_TRAINER_ID=0 - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export PADDLE_TRAINER_ID=0 + >>> import paddle.distributed as dist + + >>> env = dist.ParallelEnv() + >>> print("The rank is %d" % env.rank) + The rank is 0 - env = dist.ParallelEnv() - print("The rank is %d" % env.rank) - # The rank is 0 """ return self._rank @@ -753,14 +744,16 @@ def world_size(self): Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` . The default value is 1. Examples: - .. code-block:: python + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 + >>> import paddle.distributed as dist - # execute this command in terminal: export PADDLE_TRAINERS_NUM=4 - import paddle.distributed as dist + >>> env = dist.ParallelEnv() + >>> print("The world_size is %d" % env.world_size) + The world_size is 4 - env = dist.ParallelEnv() - print("The world_size is %d" % env.world_size) - # The world_size is 4 """ return self._world_size @@ -772,14 +765,15 @@ def device_id(self): Its value is equal to the value of the environment variable ``FLAGS_selected_gpus`` . The default value is 0. Examples: - .. code-block:: python + .. code-block:: python - # execute this command in terminal: export FLAGS_selected_gpus=1 - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export FLAGS_selected_gpus=1 + >>> import paddle.distributed as dist - env = dist.ParallelEnv() - print("The device id are %d" % env.device_id) - # The device id are 1 + >>> env = dist.ParallelEnv() + >>> print("The device id are %d" % env.device_id) + The device id are 1 """ return self._device_id @@ -801,14 +795,15 @@ def current_endpoint(self): Its value is equal to the value of the environment variable ``PADDLE_CURRENT_ENDPOINT`` . The default value is "". Examples: - .. code-block:: python + .. code-block:: python - # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170 - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export PADDLE_CURRENT_ENDPOINT=127.0.0.1:6170 + >>> import paddle.distributed as dist - env = dist.ParallelEnv() - print("The current endpoint are %s" % env.current_endpoint) - # The current endpoint are 127.0.0.1:6170 + >>> env = dist.ParallelEnv() + >>> print("The current endpoint are %s" % env.current_endpoint) + The current endpoint are 127.0.0.1:6170 """ return self._current_endpoint @@ -821,14 +816,16 @@ def trainer_endpoints(self): Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ENDPOINTS`` . The default value is "". Examples: - .. code-block:: python + .. code-block:: python + + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171 + >>> import paddle.distributed as dist - # execute this command in terminal: export PADDLE_TRAINER_ENDPOINTS=127.0.0.1:6170,127.0.0.1:6171 - import paddle.distributed as dist + >>> env = dist.ParallelEnv() + >>> print("The trainer endpoints are %s" % env.trainer_endpoints) + The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171'] - env = dist.ParallelEnv() - print("The trainer endpoints are %s" % env.trainer_endpoints) - # The trainer endpoints are ['127.0.0.1:6170', '127.0.0.1:6171'] """ return self._trainer_endpoints @@ -840,14 +837,15 @@ def nrings(self): Its value is equal to the value of the environment variable ``FLAGS_nccl_nrings`` . The default value is 1. Examples: - .. code-block:: python + .. code-block:: python - # execute this command in terminal: export FLAGS_nccl_nrings=1 - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # execute this command in terminal: export FLAGS_nccl_nrings=1 + >>> import paddle.distributed as dist - env = dist.ParallelEnv() - print("The nrings is %d" % env.nrings) - # the number of ring is 1 + >>> env = dist.ParallelEnv() + >>> print("The nrings is %d" % env.nrings) + The nrings is 1 """ return self._nrings @@ -940,46 +938,40 @@ def init_parallel_env(): Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.nn as nn - import paddle.optimizer as opt - import paddle.distributed as dist - - class LinearNet(nn.Layer): - def __init__(self): - super().__init__() - self._linear1 = nn.Linear(10, 10) - self._linear2 = nn.Linear(10, 1) - - def forward(self, x): - return self._linear2(self._linear1(x)) - - def train(): - # 1. initialize parallel environment - dist.init_parallel_env() - - # 2. create data parallel layer & optimizer - layer = LinearNet() - dp_layer = paddle.DataParallel(layer) - - loss_fn = nn.MSELoss() - adam = opt.Adam( - learning_rate=0.001, parameters=dp_layer.parameters()) - - # 3. run layer - inputs = paddle.randn([10, 10], 'float32') - outputs = dp_layer(inputs) - labels = paddle.randn([10, 1], 'float32') - loss = loss_fn(outputs, labels) - - loss.backward() - - adam.step() - adam.clear_grad() - - if __name__ == '__main__': - dist.spawn(train) + >>> # doctest: +REQUIRES(env:GPU, env:DISTRIBUTED) + >>> import paddle + >>> import paddle.nn as nn + >>> import paddle.optimizer as opt + >>> import paddle.distributed as dist + + >>> class LinearNet(nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self._linear1 = nn.Linear(10, 10) + ... self._linear2 = nn.Linear(10, 1) + ... def forward(self, x): + ... return self._linear2(self._linear1(x)) + + >>> def train(): + ... # 1. initialize parallel environment + ... dist.init_parallel_env() + ... # 2. create data parallel layer & optimizer + ... layer = LinearNet() + ... dp_layer = paddle.DataParallel(layer) + ... loss_fn = nn.MSELoss() + ... adam = opt.Adam( + ... learning_rate=0.001, parameters=dp_layer.parameters()) + ... # 3. run layer + ... inputs = paddle.randn([10, 10], 'float32') + ... outputs = dp_layer(inputs) + ... labels = paddle.randn([10, 1], 'float32') + ... loss = loss_fn(outputs, labels) + ... loss.backward() + ... adam.step() + ... adam.clear_grad() + + >>> if __name__ == '__main__': + ... dist.spawn(train) """ @@ -1213,13 +1205,15 @@ def get_rank(group=None): Examples: .. code-block:: python - # Execute this script using distributed launch with one card configs. - import paddle - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # Execute this script using distributed launch with one card configs. + >>> import paddle + >>> import paddle.distributed as dist + + >>> dist.init_parallel_env() + >>> print("The rank is %d" % dist.get_rank()) + The rank is 0 - dist.init_parallel_env() - print("The rank is %d" % dist.get_rank()) - # The rank is 0 """ if in_dynamic_mode() and group: return group.rank @@ -1245,13 +1239,15 @@ def get_world_size(group=None): Examples: .. code-block:: python - # Execute this script using distributed launch with one card configs. - import paddle - import paddle.distributed as dist + >>> # doctest: +REQUIRES(env:DISTRIBUTED) + >>> # Execute this script using distributed launch with one card configs. + >>> import paddle + >>> import paddle.distributed as dist + + >>> dist.init_parallel_env() + >>> print("The world_size is %d" % dist.get_world_size()) + The world_size is 1 - dist.init_parallel_env() - print("The world_size is %d" % dist.get_world_size()) - # The world_size is 1 """ if in_dynamic_mode() and (group is None): if is_initialized(): diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index f804b59a2db2c6..e0d726b2957422 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -48,8 +48,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): removed_op_idx.append(idx) # del op from dist_context - if dist_context: - dist_context.del_dist_op_for_program(op) + # if dist_context: + # dist_context.del_dist_op_for_program(op) for idx in removed_op_idx[::-1]: main_block._remove_op(idx, sync=False) @@ -229,6 +229,7 @@ def _create_cond_block_and_update_optimizer( optimize_ops_block, k_steps, avg, + dist_context, ): def true_apply_gradient(): cur_block_idx = main_program.current_block_idx @@ -285,6 +286,14 @@ def true_apply_gradient(): main_program.global_block()._sync_with_cpp() cur_block._sync_with_cpp() + # update serial op + for idx, op in enumerate(cur_block.ops): + if is_optimize_op(op): + dist_op = dist_context.get_dist_op_for_program(op) + if dist_op: + # dist_op.set_input_dist_attr + dist_op._serial_op = op + # clear gradient_merge_vars for param, new_grad in new_params_to_grads: paddle.tensor.fill_constant( @@ -331,6 +340,7 @@ def parse_program( optimize_ops_block, k_steps, avg, + dist_context, ) diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index f7b211fdc4ba41..6c3ee4d8d8e951 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -1691,11 +1691,10 @@ def re_order_program(block, param_grads, dist_context): if is_optimize_op(last_op) and last_op.type in _supported_optimizer_type: # record optimizer for idx, op in reversed(list(enumerate(block.ops))): - if op.type not in _supported_optimizer_type: - break - assert len(op.input("Param")) == 1 - pname_to_op[op.input("Param")[0]] = op - remove_op_indices.append(idx) + if op.type in _supported_optimizer_type: + assert len(op.input("Param")) == 1 + pname_to_op[op.input("Param")[0]] = op + remove_op_indices.append(idx) assert len(use_order) == len(pname_to_op) # append new opts diff --git a/python/paddle/distribution/bernoulli.py b/python/paddle/distribution/bernoulli.py index 7d4849fab48e7c..152306aea31f7c 100644 --- a/python/paddle/distribution/bernoulli.py +++ b/python/paddle/distribution/bernoulli.py @@ -212,6 +212,7 @@ def rsample(self, shape, temperature=1.0): .. code-block:: python >>> import paddle + >>> paddle.seed(1) >>> from paddle.distribution import Bernoulli >>> rv = Bernoulli(paddle.full((1), 0.3)) @@ -231,28 +232,26 @@ def rsample(self, shape, temperature=1.0): [100, 2, 2] >>> # `rsample` has to be followed by a `sigmoid` - >>> # doctest: +SKIP >>> rv = Bernoulli(0.3) >>> rsample = rv.rsample([3, ]) >>> rsample_sigmoid = paddle.nn.functional.sigmoid(rsample) - >>> print(rsample, rsample_sigmoid) - Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, - [[-0.88315082], - [-0.62347704], - [-0.31513220]]) - Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True, - [[0.29252526], - [0.34899110], - [0.42186251]]) + >>> print(rsample) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [-1.46112013, -0.01239836, -1.32765460]) + >>> print(rsample_sigmoid) + Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.18829606, 0.49690047, 0.20954758]) >>> # The smaller the `temperature`, the distribution of `rsample` closer to `sample`, with `probs` of 0.3. >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=1.0)).sum()) + >>> # doctest: +SKIP('output will be different') Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 361.06829834) + 365.63122559) + >>> # doctest: -SKIP >>> print(paddle.nn.functional.sigmoid(rv.rsample([1000, ], temperature=0.1)).sum()) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, - 288.66418457) + 320.15057373) """ name = self.name + '_rsample' if not in_dynamic_mode(): diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index b6484e3f21d563..9d5664dc28f4d3 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -64,14 +64,12 @@ class Categorical(distribution.Distribution): >>> cat = Categorical(x) >>> cat2 = Categorical(y) - >>> # doctest: +SKIP >>> paddle.seed(1000) # on CPU device >>> print(cat.sample([2,3])) Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, [[0, 1, 5], [3, 4, 5]]) - >>> # doctest: -SKIP >>> print(cat.entropy()) Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 1.77528250) diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 82315aa72d7d9a..ecddb82c9a3752 100755 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -37,6 +37,7 @@ _apply_pass, _create_tensor, _current_expected_place, + _current_expected_place_, _dygraph_tracer, _get_paddle_place, _global_flags, diff --git a/python/paddle/hapi/dynamic_flops.py b/python/paddle/hapi/dynamic_flops.py index fcae6e4120ac8e..d02610f6e51848 100644 --- a/python/paddle/hapi/dynamic_flops.py +++ b/python/paddle/hapi/dynamic_flops.py @@ -85,7 +85,6 @@ def flops(net, input_size, custom_ops=None, print_detail=False): ... [1, 1, 28, 28], ... custom_ops= {nn.LeakyReLU: count_leaky_relu}, ... print_detail=True) - >>> # doctest: +SKIP >>> print(FLOPs) <class 'paddle.nn.layer.conv.Conv2D'>'s flops has been counted <class 'paddle.nn.layer.activation.ReLU'>'s flops has been counted @@ -106,7 +105,6 @@ def flops(net, input_size, custom_ops=None, print_detail=False): +--------------+-----------------+-----------------+--------+--------+ Total Flops: 347560 Total Params: 61610 347560 - >>> # doctest: -SKIP """ if isinstance(net, nn.Layer): # If net is a dy2stat model, net.forward is StaticFunction instance, diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 8ca5712a3036c2..55814227385989 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -2399,7 +2399,6 @@ def summary(self, input_size=None, dtype=None): >>> optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) >>> model.prepare(optim, paddle.nn.CrossEntropyLoss()) >>> params_info = model.summary() - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -2424,7 +2423,6 @@ def summary(self, input_size=None, dtype=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP """ assert ( diff --git a/python/paddle/hapi/model_summary.py b/python/paddle/hapi/model_summary.py index df5791a5fd70d8..bedd109b0a532b 100644 --- a/python/paddle/hapi/model_summary.py +++ b/python/paddle/hapi/model_summary.py @@ -78,7 +78,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> lenet = LeNet() >>> params_info = paddle.summary(lenet, (1, 1, 28, 28)) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -103,7 +102,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # multi input demo >>> class LeNetMultiInput(LeNet): ... def forward(self, inputs, y): @@ -119,7 +117,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> params_info = paddle.summary(lenet_multi_input, ... [(1, 1, 28, 28), (1, 400)], ... dtypes=['float32', 'float32']) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -144,7 +141,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # list input demo >>> class LeNetListInput(LeNet): ... def forward(self, inputs): @@ -158,7 +154,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> lenet_list_input = LeNetListInput() >>> input_data = [paddle.rand([1, 1, 28, 28]), paddle.rand([1, 400])] >>> params_info = paddle.summary(lenet_list_input, input=input_data) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -183,7 +178,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP >>> # dict input demo >>> class LeNetDictInput(LeNet): ... def forward(self, inputs): @@ -198,7 +192,6 @@ def summary(net, input_size=None, dtypes=None, input=None): >>> input_data = {'x1': paddle.rand([1, 1, 28, 28]), ... 'x2': paddle.rand([1, 400])} >>> params_info = paddle.summary(lenet_dict_input, input=input_data) - >>> # doctest: +SKIP >>> print(params_info) --------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # @@ -223,7 +216,6 @@ def summary(net, input_size=None, dtypes=None, input=None): Estimated Total Size (MB): 0.35 --------------------------------------------------------------------------- {'total_params': 61610, 'trainable_params': 61610} - >>> # doctest: -SKIP """ if input_size is None and input is None: diff --git a/python/paddle/incubate/asp/asp.py b/python/paddle/incubate/asp/asp.py index 041132047dc718..9ffaee1c2b5048 100644 --- a/python/paddle/incubate/asp/asp.py +++ b/python/paddle/incubate/asp/asp.py @@ -47,75 +47,75 @@ def set_excluded_layers(param_names, main_program=None): If None is given, then it would be set as `paddle.static.default_main_program(). Default is None. Examples: - 1. Usage of Dynamic Graph - - .. code-block:: python - - >>> import paddle - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 100) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... prediction = self.linear1(hidden) - ... return prediction - - >>> my_layer = MyLayer() - >>> optimizer = paddle.optimizer.SGD( - ... learning_rate=0.01, parameters=my_layer.parameters()) - - >>> # Need to set excluded layers before calling decorate - >>> paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) - - >>> optimizer = paddle.incubate.asp.decorate(optimizer) - - 2. Usage of Static Graph - - .. code-block:: python - - >>> import paddle - - >>> paddle.enable_static() - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 100) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... prediction = self.linear1(hidden) - ... return prediction - - >>> main_program = paddle.static.Program() - >>> startup_program = paddle.static.Program() - - >>> with paddle.static.program_guard(main_program, startup_program): - ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) - ... label = paddle.static.data(name='label', shape=[None, 100]) - ... my_layer = MyLayer() - ... prob = my_layer(input_data) - ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - ... - ... # Setup exluded layers out from ASP workflow. - ... # Please note, excluded_layers must be set before calling optimizer.minimize(). - ... paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) - ... - ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) - ... optimizer = paddle.static.amp.decorate(optimizer ) - ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which - ... # will insert necessary masking operations for ASP workflow. - ... optimizer = paddle.incubate.asp.decorate(optimizer) - ... optimizer.minimize(loss, startup_program) + .. code-block:: python + :name: dynamic-graph + + >>> # Example1: Usage of Dynamic Graph + >>> import paddle + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 100) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... prediction = self.linear1(hidden) + ... return prediction + + >>> my_layer = MyLayer() + >>> optimizer = paddle.optimizer.SGD( + ... learning_rate=0.01, parameters=my_layer.parameters()) + + >>> # Need to set excluded layers before calling decorate + >>> paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) + + >>> optimizer = paddle.incubate.asp.decorate(optimizer) + + .. code-block:: python + :name: static-graph + + >>> # Example2: Usage of Static Graph + >>> import paddle + + >>> paddle.enable_static() + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 100) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... prediction = self.linear1(hidden) + ... return prediction + + >>> main_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + + >>> with paddle.static.program_guard(main_program, startup_program): + ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + ... label = paddle.static.data(name='label', shape=[None, 100]) + ... my_layer = MyLayer() + ... prob = my_layer(input_data) + ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + ... + ... # Setup exluded layers out from ASP workflow. + ... # Please note, excluded_layers must be set before calling optimizer.minimize(). + ... paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) + ... + ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) + ... optimizer = paddle.static.amp.decorate(optimizer ) + ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + ... # will insert necessary masking operations for ASP workflow. + ... optimizer = paddle.incubate.asp.decorate(optimizer) + ... optimizer.minimize(loss, startup_program) """ if main_program is None: main_program = paddle.static.default_main_program() @@ -134,81 +134,81 @@ def reset_excluded_layers(main_program=None): If None is given, then this function would reset all excluded_layers. Default is None. Examples: - 1. Usage of Dynamic Graph - - .. code-block:: python - - >>> import paddle - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 100) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... prediction = self.linear1(hidden) - ... return prediction - - >>> my_layer = MyLayer() - >>> optimizer = paddle.optimizer.SGD( - ... learning_rate=0.01, parameters=my_layer.parameters()) - - >>> # Need to set excluded layers before calling decorate - >>> paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) - >>> # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. - >>> # Please note, reset_excluded_layers also must be called before calling asp.decorate(). - >>> paddle.incubate.asp.reset_excluded_layers() - - >>> optimizer = paddle.incubate.asp.decorate(optimizer) - - 2. Usage of Static Graph - - .. code-block:: python - - >>> import paddle - - >>> paddle.enable_static() - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 100) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... prediction = self.linear1(hidden) - ... return prediction - - >>> main_program = paddle.static.Program() - >>> startup_program = paddle.static.Program() - - >>> with paddle.static.program_guard(main_program, startup_program): - ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) - ... label = paddle.static.data(name='label', shape=[None, 100]) - ... my_layer = MyLayer() - ... prob = my_layer(input_data) - ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - ... - ... # Setup exluded layers out from ASP workflow. - ... # Please note, excluded_layers must be set before calling optimizer.minimize(). - ... paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) - ... # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. - ... # Please note, reset_excluded_layers also must be called before calling optimizer.minimize(). - ... paddle.incubate.asp.reset_excluded_layers(main_program) - ... - ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) - ... optimizer = paddle.static.amp.decorate(optimizer ) - ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which - ... # will insert necessary masking operations for ASP workflow. - ... optimizer = paddle.incubate.asp.decorate(optimizer) - ... optimizer.minimize(loss, startup_program) + .. code-block:: python + :name: dynamic-graph + + >>> # Example1: Usage of Dynamic Graph + >>> import paddle + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 100) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... prediction = self.linear1(hidden) + ... return prediction + + >>> my_layer = MyLayer() + >>> optimizer = paddle.optimizer.SGD( + ... learning_rate=0.01, parameters=my_layer.parameters()) + + >>> # Need to set excluded layers before calling decorate + >>> paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()]) + >>> # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. + >>> # Please note, reset_excluded_layers also must be called before calling asp.decorate(). + >>> paddle.incubate.asp.reset_excluded_layers() + + >>> optimizer = paddle.incubate.asp.decorate(optimizer) + + .. code-block:: python + :name: static-graph + + >>> # Example2: Usage of Static Graph + >>> import paddle + + >>> paddle.enable_static() + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 100) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... prediction = self.linear1(hidden) + ... return prediction + + >>> main_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + + >>> with paddle.static.program_guard(main_program, startup_program): + ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + ... label = paddle.static.data(name='label', shape=[None, 100]) + ... my_layer = MyLayer() + ... prob = my_layer(input_data) + ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + ... + ... # Setup exluded layers out from ASP workflow. + ... # Please note, excluded_layers must be set before calling optimizer.minimize(). + ... paddle.incubate.asp.set_excluded_layers([my_layer.linear1.full_name()], main_program) + ... # Reset excluded_layers, all supported layers would be included into Automatic SParsity's workflow. + ... # Please note, reset_excluded_layers also must be called before calling optimizer.minimize(). + ... paddle.incubate.asp.reset_excluded_layers(main_program) + ... + ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) + ... optimizer = paddle.static.amp.decorate(optimizer ) + ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + ... # will insert necessary masking operations for ASP workflow. + ... optimizer = paddle.incubate.asp.decorate(optimizer) + ... optimizer.minimize(loss, startup_program) """ ASPHelper.reset_excluded_layers(main_program=main_program) @@ -225,76 +225,76 @@ def decorate(optimizer): Returns: OptimizerWithSparsityGuarantee: A wrapper for ASP to decorate `minimize` function of the given optimizer. Examples: - 1. Usage of Dynamic Graph - - .. code-block:: python - - >>> import paddle - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 32) - ... self.linear2 = paddle.nn.Linear(32, 32) - ... self.linear3 = paddle.nn.Linear(32, 10) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... hidden = self.linear1(hidden) - ... hidden = self.linear2(hidden) - ... prediction = self.linear3(hidden) - ... return prediction - - >>> my_layer = MyLayer() - >>> optimizer = paddle.optimizer.SGD( - ... learning_rate=0.01, parameters=my_layer.parameters()) - - >>> # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which - >>> # will apply necessary masking operations for ASP workflow. - >>> # In dynamic graph mode, ASP would create related mask variables during decoration. - >>> optimizer = paddle.incubate.asp.decorate(optimizer) - - 2. Usage of Static Graph - - .. code-block:: python - - >>> import paddle - - >>> paddle.enable_static() - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 100) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... prediction = self.linear1(hidden) - ... return prediction - - >>> main_program = paddle.static.Program() - >>> startup_program = paddle.static.Program() - - >>> with paddle.static.program_guard(main_program, startup_program): - ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) - ... label = paddle.static.data(name='label', shape=[None, 100]) - ... my_layer = MyLayer() - ... prob = my_layer(input_data) - ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - ... - ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) - ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which - ... # will insert necessary masking operations for ASP workflow. - ... # In static graph mode, ASP creates related mask variables - ... # during minimize(). - ... optimizer = paddle.incubate.asp.decorate(optimizer) - ... optimizer.minimize(loss, startup_program) + .. code-block:: python + :name: dynamic-graph + + >>> # Example1: Usage of Dynamic Graph + >>> import paddle + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 32) + ... self.linear2 = paddle.nn.Linear(32, 32) + ... self.linear3 = paddle.nn.Linear(32, 10) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... hidden = self.linear1(hidden) + ... hidden = self.linear2(hidden) + ... prediction = self.linear3(hidden) + ... return prediction + + >>> my_layer = MyLayer() + >>> optimizer = paddle.optimizer.SGD( + ... learning_rate=0.01, parameters=my_layer.parameters()) + + >>> # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which + >>> # will apply necessary masking operations for ASP workflow. + >>> # In dynamic graph mode, ASP would create related mask variables during decoration. + >>> optimizer = paddle.incubate.asp.decorate(optimizer) + + .. code-block:: python + :name: static-graph + + >>> # Example2: Usage of Static Graph + >>> import paddle + + >>> paddle.enable_static() + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 100) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... prediction = self.linear1(hidden) + ... return prediction + + >>> main_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + + >>> with paddle.static.program_guard(main_program, startup_program): + ... input_data = paddle.static.data(name='data', shape=[None, 3, 224, 224]) + ... label = paddle.static.data(name='label', shape=[None, 100]) + ... my_layer = MyLayer() + ... prob = my_layer(input_data) + ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + ... + ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) + ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + ... # will insert necessary masking operations for ASP workflow. + ... # In static graph mode, ASP creates related mask variables + ... # during minimize(). + ... optimizer = paddle.incubate.asp.decorate(optimizer) + ... optimizer.minimize(loss, startup_program) """ return ASPHelper.decorate(optimizer) @@ -322,116 +322,116 @@ def prune_model(model, n=2, m=4, mask_algo='mask_1d', with_mask=True): Returns: dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable. Examples: - 1. Usage of Dynamic Graph - - .. code-block:: python - - >>> import paddle - >>> import numpy as np - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 32) - ... self.linear2 = paddle.nn.Linear(32, 32) - ... self.linear3 = paddle.nn.Linear(32, 10) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... hidden = self.linear1(hidden) - ... hidden = self.linear2(hidden) - ... prediction = self.linear3(hidden) - ... return prediction - - >>> my_layer = MyLayer() - >>> loss_fn = paddle.nn.MSELoss(reduction='mean') - - >>> optimizer = paddle.optimizer.SGD( - ... learning_rate=0.01, parameters=my_layer.parameters()) - - >>> # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which - >>> # will apply necessary masking operations for ASP workflow. - >>> # In dynamic graph mode, ASP would create related mask variables during decoration. - >>> optimizer = paddle.incubate.asp.decorate(optimizer) - - >>> # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model() - >>> paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') - - >>> for i in range(10): - ... imgs = paddle.to_tensor( - ... np.random.randn(64, 3, 32, 32), - ... dtype='float32', stop_gradient=False) - ... labels = paddle.to_tensor( - ... np.random.randint(10, size=(64, 1)), - ... dtype='float32', stop_gradient=False) - ... output = my_layer(imgs) - ... loss = loss_fn(output, labels) - ... loss.backward() - ... optimizer.step() - ... optimizer.clear_grad() - - 2. Usage of Static Graph - - .. code-block:: python - - >>> import paddle - >>> import numpy as np - - >>> paddle.enable_static() - - >>> class MyLayer(paddle.nn.Layer): - ... def __init__(self): - ... super().__init__() - ... self.conv1 = paddle.nn.Conv2D( - ... in_channels=3, out_channels=4, kernel_size=3, padding=2) - ... self.linear1 = paddle.nn.Linear(4624, 32) - ... self.linear2 = paddle.nn.Linear(32, 32) - ... self.linear3 = paddle.nn.Linear(32, 10) - ... - ... def forward(self, img): - ... hidden = self.conv1(img) - ... hidden = paddle.flatten(hidden, start_axis=1) - ... hidden = self.linear1(hidden) - ... hidden = self.linear2(hidden) - ... prediction = self.linear3(hidden) - ... return prediction - - >>> main_program = paddle.static.Program() - >>> startup_program = paddle.static.Program() - - >>> with paddle.static.program_guard(main_program, startup_program): - ... input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32]) - ... label = paddle.static.data(name='label', shape=[None, 1]) - ... my_layer = MyLayer() - ... prob = my_layer(input_data) - ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) - ... - ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) - ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which - ... # will insert necessary masking operations for ASP workflow. - ... # In static graph mode, ASP creates related mask variables - ... # during minimize(). - ... optimizer = paddle.incubate.asp.decorate(optimizer) - ... optimizer.minimize(loss, startup_program) - - >>> device = paddle.device.get_device() - >>> place = paddle.set_device(device) - - >>> exe = paddle.static.Executor(place) - >>> exe.run(startup_program) - - >>> # Must call exe.run(startup_program) first before calling paddle.asp.prune_model() - >>> paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') - >>> # it also be accepted to call - >>> # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best') - - >>> for i in range(10): - ... imgs = np.random.randn(64, 3, 32, 32).astype('float32') - ... labels = np.random.randint(10, size=(64, 1)).astype('float32') - ... exe.run(main_program, feed={'data':imgs, 'label':labels}) + .. code-block:: python + :name: dynamic-graph + + >>> # Example1: Usage of Dynamic Graph + >>> import paddle + >>> import numpy as np + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 32) + ... self.linear2 = paddle.nn.Linear(32, 32) + ... self.linear3 = paddle.nn.Linear(32, 10) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... hidden = self.linear1(hidden) + ... hidden = self.linear2(hidden) + ... prediction = self.linear3(hidden) + ... return prediction + + >>> my_layer = MyLayer() + >>> loss_fn = paddle.nn.MSELoss(reduction='mean') + + >>> optimizer = paddle.optimizer.SGD( + ... learning_rate=0.01, parameters=my_layer.parameters()) + + >>> # Calling paddle.incubate.asp.decorate() to wrap step() in optimizer, which + >>> # will apply necessary masking operations for ASP workflow. + >>> # In dynamic graph mode, ASP would create related mask variables during decoration. + >>> optimizer = paddle.incubate.asp.decorate(optimizer) + + >>> # Must call paddle.incubate.asp.decorate() first before calling paddle.incubate.asp.prune_model() + >>> paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') + + >>> for i in range(10): + ... imgs = paddle.to_tensor( + ... np.random.randn(64, 3, 32, 32), + ... dtype='float32', stop_gradient=False) + ... labels = paddle.to_tensor( + ... np.random.randint(10, size=(64, 1)), + ... dtype='float32', stop_gradient=False) + ... output = my_layer(imgs) + ... loss = loss_fn(output, labels) + ... loss.backward() + ... optimizer.step() + ... optimizer.clear_grad() + + .. code-block:: python + :name: static-graph + + >>> # Example2: Usage of Static Graph + >>> import paddle + >>> import numpy as np + + >>> paddle.enable_static() + + >>> class MyLayer(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = paddle.nn.Conv2D( + ... in_channels=3, out_channels=4, kernel_size=3, padding=2) + ... self.linear1 = paddle.nn.Linear(4624, 32) + ... self.linear2 = paddle.nn.Linear(32, 32) + ... self.linear3 = paddle.nn.Linear(32, 10) + ... + ... def forward(self, img): + ... hidden = self.conv1(img) + ... hidden = paddle.flatten(hidden, start_axis=1) + ... hidden = self.linear1(hidden) + ... hidden = self.linear2(hidden) + ... prediction = self.linear3(hidden) + ... return prediction + + >>> main_program = paddle.static.Program() + >>> startup_program = paddle.static.Program() + + >>> with paddle.static.program_guard(main_program, startup_program): + ... input_data = paddle.static.data(name='data', shape=[None, 3, 32, 32]) + ... label = paddle.static.data(name='label', shape=[None, 1]) + ... my_layer = MyLayer() + ... prob = my_layer(input_data) + ... loss = paddle.mean(paddle.nn.functional.square_error_cost(prob, label)) + ... + ... optimizer = paddle.optimizer.SGD(learning_rate=0.1) + ... # Calling paddle.incubate.asp.decorate() to wrap minimize() in optimizer, which + ... # will insert necessary masking operations for ASP workflow. + ... # In static graph mode, ASP creates related mask variables + ... # during minimize(). + ... optimizer = paddle.incubate.asp.decorate(optimizer) + ... optimizer.minimize(loss, startup_program) + + >>> device = paddle.device.get_device() + >>> place = paddle.set_device(device) + + >>> exe = paddle.static.Executor(place) + >>> exe.run(startup_program) + + >>> # Must call exe.run(startup_program) first before calling paddle.asp.prune_model() + >>> paddle.incubate.asp.prune_model(my_layer, mask_algo='mask_2d_best') + >>> # it also be accepted to call + >>> # paddle.incubate.asp.prune_model(main_program, mask_algo='mask_2d_best') + + >>> for i in range(10): + ... imgs = np.random.randn(64, 3, 32, 32).astype('float32') + ... labels = np.random.randint(10, size=(64, 1)).astype('float32') + ... exe.run(main_program, feed={'data':imgs, 'label':labels}) """ device = paddle.device.get_device() place = paddle.set_device(device) diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py index 3995cd4a4087d0..99f9c4e72e77d0 100644 --- a/python/paddle/incubate/nn/functional/fused_rms_norm.py +++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py @@ -54,14 +54,15 @@ def fused_rms_norm( Examples: .. code-block:: python - # required: gpu - import paddle + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') - paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) - paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) - paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) - epsilon = 1e-6 - paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) + >>> paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16) + >>> paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) + >>> paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float16) + >>> epsilon = 1e-6 + >>> paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1) """ if in_dynamic_or_pir_mode(): return _C_ops.rms_norm( diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index 0b667687c114bf..adfcdc233fe567 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -14,6 +14,7 @@ from paddle import _C_ops +from paddle.base.layer_helper import LayerHelper from paddle.framework import in_dynamic_mode @@ -91,6 +92,28 @@ def fused_rotary_position_embedding( q, k, v, sin, cos, position_ids, use_neox_rotary_style ) - raise RuntimeError( - "This feature is currently supported only in dynamic mode and with CUDAPlace." + helper = LayerHelper('fused_rotary_position_embedding', **locals()) + out_q = helper.create_variable_for_type_inference(dtype=q.dtype) + out_k = ( + helper.create_variable_for_type_inference(dtype=k.dtype) if k else None ) + out_v = ( + helper.create_variable_for_type_inference(dtype=v.dtype) if v else None + ) + helper.append_op( + type='fused_rotary_position_embedding', + inputs={ + 'q': q, + 'k': k, + 'v': v, + 'sin': sin, + 'cos': cos, + 'position_ids': position_ids, + }, + outputs={'out_q': out_q, 'out_k': out_k, 'out_v': out_v}, + attrs={ + 'use_neox_rotary_style': use_neox_rotary_style, + }, + ) + + return out_q, out_k, out_v diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 355b5916b5ddb2..c4cf8abfdb3546 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -56,20 +56,20 @@ def fused_feedforward( This operator only supports running on GPU. The function of the operator is consistent with the following pseudo code: - .. code-block:: python - - residual = x - if pre_layer_norm: - out = layer_norm1(x) - else: - out = x - out = linear2(dropout1(activation(linear1(src)))) - if add_residual: - out = residual + dropout2(out) - else: - out = dropout2(out) - if not pre_layer_norm: - out = layer_norm2(out) + .. code-block:: text + + >>> residual = x + >>> if pre_layer_norm: + ... out = layer_norm1(x) + ... else: + ... out = x + >>> out = linear2(dropout1(activation(linear1(src)))) + >>> if add_residual: + ... out = residual + dropout2(out) + ... else: + ... out = dropout2(out) + >>> if not pre_layer_norm: + ... out = layer_norm2(out) Args: @@ -110,16 +110,17 @@ def fused_feedforward( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - x = paddle.randn(shape=(1, 8, 8), dtype="float32") - linear1_weight = paddle.randn(shape=(8, 8), dtype="float32") - linear2_weight = paddle.randn(shape=(8, 8), dtype="float32") - out = F.fused_feedforward(x, linear1_weight, linear2_weight) - print(out.shape) - # (1, 8, 8) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> x = paddle.randn(shape=(1, 8, 8), dtype="float32") + >>> linear1_weight = paddle.randn(shape=(8, 8), dtype="float32") + >>> linear2_weight = paddle.randn(shape=(8, 8), dtype="float32") + >>> out = F.fused_feedforward(x, linear1_weight, linear2_weight) + >>> print(out.shape) + [1, 8, 8] """ _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) @@ -288,9 +289,9 @@ def fused_bias_dropout_residual_layer_norm( The fused_bias_dropout_residual_layer_norm operator. The pseudo code is as follows: - .. code-block:: python + .. code-block:: text - y = layer_norm(residual + dropout(bias + x)) + >>> y = layer_norm(residual + dropout(bias + x)) Parameters: x (Tensor): The input tensor. The shape is `[*, embed\_dim]`. @@ -323,21 +324,22 @@ def fused_bias_dropout_residual_layer_norm( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - # residual: [batch_size, seq_len, embed_dim] - residual = paddle.rand(shape=(2, 4, 128), dtype="float32") - # linear bias: [embed_dim] - bias = paddle.rand(shape=[128], dtype="float32") - # output: [batch_size, seq_len, embed_dim] - output = F.fused_bias_dropout_residual_layer_norm( - x, residual, bias) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # residual: [batch_size, seq_len, embed_dim] + >>> residual = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # linear bias: [embed_dim] + >>> bias = paddle.rand(shape=[128], dtype="float32") + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_bias_dropout_residual_layer_norm( + ... x, residual, bias) + >>> print(output.shape) + [2, 4, 128] """ seed = None @@ -493,35 +495,35 @@ def fused_multi_head_attention( to information from different representation subspaces. This API only support self_attention. The pseudo code is as follows: - .. code-block:: python - - residual = x - if pre_layer_norm: - out = layer_norm(x) - else: - out = x - # compute q, k, v - out = matmul(out, qkv_weight) + qkv_bias - out = transpose(out, perm=[2, 0, 3, 1, 4]) - # extract q, k and v from out - q = out[0:1,::] * (head_dim ** -0.5) - k = out[1:2,::] - v = out[2:3,::] - out = matmul(q, k, transpose_y=True) - out = out + attn_mask - out = softmax(out) - out = dropout(out) - out = matmul(out, v) - # combine heads - out = transpose(out, perm=[0, 2, 1, 3]) - # project to output - out = linear(out) - if add_residual: - out = residual + dropout(out) - else: - out = dropout(out) - if not pre_layer_norm: - out = layer_norm(out) + .. code-block:: text + + >>> residual = x + >>> if pre_layer_norm: + ... out = layer_norm(x) + ... else: + ... out = x + >>> # compute q, k, v + >>> out = matmul(out, qkv_weight) + qkv_bias + >>> out = transpose(out, perm=[2, 0, 3, 1, 4]) + >>> # extract q, k and v from out + >>> q = out[0:1,::] * (head_dim ** -0.5) + >>> k = out[1:2,::] + >>> v = out[2:3,::] + >>> out = matmul(q, k, transpose_y=True) + >>> out = out + attn_mask + >>> out = softmax(out) + >>> out = dropout(out) + >>> out = matmul(out, v) + >>> # combine heads + >>> out = transpose(out, perm=[0, 2, 1, 3]) + >>> # project to output + >>> out = linear(out) + >>> if add_residual: + ... out = residual + dropout(out) + ... else: + ... out = dropout(out) + >>> if not pre_layer_norm: + ... out = layer_norm(out) Parameters: @@ -581,30 +583,31 @@ def fused_multi_head_attention( .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - # qkv_weight: [3, num_head, head_dim, embed_dim] - qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - # qkv_bias: [3, num_head, head_dim] - qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") - # linear_weight: [embed_dim, embed_dim] - linear_weight = paddle.rand(shape=(128, 128), dtype="float32") - # linear_bias: [embed_dim] - linear_bias = paddle.rand(shape=[128], dtype="float32") - # self attention mask: [batch_size, num_heads, seq_len, seq_len] - attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") - - # output: [batch_size, seq_len, embed_dim] - output = F.fused_multi_head_attention( - x, qkv_weight, linear_weight, False, - None, None, None, None, 1e-5, qkv_bias, - linear_bias, None, attn_mask) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + >>> # qkv_weight: [3, num_head, head_dim, embed_dim] + >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + >>> # qkv_bias: [3, num_head, head_dim] + >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + >>> # linear_weight: [embed_dim, embed_dim] + >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + >>> # linear_bias: [embed_dim] + >>> linear_bias = paddle.rand(shape=[128], dtype="float32") + >>> # self attention mask: [batch_size, num_heads, seq_len, seq_len] + >>> attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") + + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_multi_head_attention( + ... x, qkv_weight, linear_weight, False, + ... None, None, None, None, 1e-5, qkv_bias, + ... linear_bias, None, attn_mask) + >>> print(output.shape) + [2, 4, 128] """ seed = None @@ -906,39 +909,39 @@ def fused_multi_transformer( This operator only supports running on GPU. The function of the transformer layer is consistent with the following pseudo code: - .. code-block:: python - - if pre_layer_norm: - out = layer_norm(x) - out = qkv_linear(out) + qkv_bias - else: - out = qkv_linear(x) + qkv_bias - out = transpose(out, perm=[2, 0, 3, 1, 4]) - # extract q, k and v from out. - q = out[0:1, ::] - k = out[1:2, ::] - v = out[2:3, ::] - out = q * k^t - out = attn_mask + out - out = softmax(out) - out = dropout(out) - out = out * v - out = transpose(out, perm=[0, 2, 1, 3]) - out = linear(out) - if pre_layer_norm: - out = x + dropout(out + bias) - else: - out = layer_norm(x + dropout(out + bias)) - - residual = out; - if pre_layer_norm: - out = ffn_layer_norm(out) - out = ffn1_linear(out) - out = dropout(activation(out + ffn1_bias)) - out = ffn2_linear(out) - out = residual + dropout(out + ffn2_bias) - if not pre_layer_norm: - out = ffn_layer_norm(out) + .. code-block:: text + + >>> if pre_layer_norm: + ... out = layer_norm(x) + ... out = qkv_linear(out) + qkv_bias + ... else: + ... out = qkv_linear(x) + qkv_bias + >>> out = transpose(out, perm=[2, 0, 3, 1, 4]) + >>> # extract q, k and v from out. + >>> q = out[0:1, ::] + >>> k = out[1:2, ::] + >>> v = out[2:3, ::] + >>> out = q * k^t + >>> out = attn_mask + out + >>> out = softmax(out) + >>> out = dropout(out) + >>> out = out * v + >>> out = transpose(out, perm=[0, 2, 1, 3]) + >>> out = linear(out) + >>> if pre_layer_norm: + ... out = x + dropout(out + bias) + ... else: + ... out = layer_norm(x + dropout(out + bias)) + + >>> residual = out; + >>> if pre_layer_norm: + ... out = ffn_layer_norm(out) + >>> out = ffn1_linear(out) + >>> out = dropout(activation(out + ffn1_bias)) + >>> out = ffn2_linear(out) + >>> out = residual + dropout(out + ffn2_bias) + >>> if not pre_layer_norm: + ... out = ffn_layer_norm(out) Args: x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16 or float32, the shape is `[batch\_size, sequence\_length, d\_model]`. @@ -996,48 +999,49 @@ def fused_multi_transformer( Examples: .. code-block:: python - # required: gpu - import paddle - import paddle.incubate.nn.functional as F - - # input: [batch_size, seq_len, embed_dim] - x = paddle.rand(shape=(2, 4, 128), dtype="float32") - - # ln_scale: [embed_dim], ln_bias: [embed_dim] - ln_scale = paddle.rand(shape=(128,), dtype="float32") - ln_bias = paddle.rand(shape=(128,), dtype="float32") - - # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] - qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") - qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") - - # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] - linear_weight = paddle.rand(shape=(128, 128), dtype="float32") - linear_bias = paddle.rand(shape=(128,), dtype="float32") - - # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] - ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") - ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") - - # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] - ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") - ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") - - # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] - ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") - ffn2_bias = paddle.rand(shape=(128,), dtype="float32") - - # self attention mask: [batch_size, 1, seq_len, seq_len] - attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") - - # output: [batch_size, seq_len, embed_dim] - output = F.fused_multi_transformer( - x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias], - [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias], - [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias], - attn_mask=attn_mask) - # [2, 4, 128] - print(output.shape) + >>> # doctest: +REQUIRES(env:GPU) + >>> import paddle + >>> paddle.device.set_device('gpu') + >>> import paddle.incubate.nn.functional as F + + >>> # input: [batch_size, seq_len, embed_dim] + >>> x = paddle.rand(shape=(2, 4, 128), dtype="float32") + + >>> # ln_scale: [embed_dim], ln_bias: [embed_dim] + >>> ln_scale = paddle.rand(shape=(128,), dtype="float32") + >>> ln_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # qkv_weight: [3, num_head, head_dim, embed_dim], qkv_bias: [3, num_head, head_dim] + >>> qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + >>> qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + + >>> # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim] + >>> linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + >>> linear_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # ffn_ln_scale: [embed_dim], ffn_ln_bias: [embed_dim] + >>> ffn_ln_scale = paddle.rand(shape=(128,), dtype="float32") + >>> ffn_ln_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # ffn1_weight: [embed_dim, 4*embed_dim], ffn1_bias: [4*embed_dim] + >>> ffn1_weight = paddle.rand(shape=(128, 4*128), dtype="float32") + >>> ffn1_bias = paddle.rand(shape=(4*128,), dtype="float32") + + >>> # ffn2_weight: [4*embed_dim, embed_dim], ffn2_bias: [embed_dim] + >>> ffn2_weight = paddle.rand(shape=(4*128, 128), dtype="float32") + >>> ffn2_bias = paddle.rand(shape=(128,), dtype="float32") + + >>> # self attention mask: [batch_size, 1, seq_len, seq_len] + >>> attn_mask = paddle.rand(shape=(2, 1, 4, 4), dtype="float32") + + >>> # output: [batch_size, seq_len, embed_dim] + >>> output = F.fused_multi_transformer( + ... x, [ln_scale], [ln_bias], [qkv_weight], [qkv_bias], + ... [linear_weight], [linear_bias], [ffn_ln_scale], [ffn_ln_bias], + ... [ffn1_weight], [ffn1_bias], [ffn2_weight], [ffn2_bias], + ... attn_mask=attn_mask) + >>> print(output.shape) + [2, 4, 128] """ if mode not in ('downscale_in_infer', 'upscale_in_train'): raise ValueError( diff --git a/python/paddle/incubate/nn/layer/fused_dropout_nd.py b/python/paddle/incubate/nn/layer/fused_dropout_nd.py index ded171158fe3dc..09f083da88c741 100644 --- a/python/paddle/incubate/nn/layer/fused_dropout_nd.py +++ b/python/paddle/incubate/nn/layer/fused_dropout_nd.py @@ -54,6 +54,7 @@ class FusedDropout(paddle.nn.Layer): .. code-block:: python >>> import paddle + >>> paddle.seed(2023) >>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32") >>> m = paddle.incubate.nn.FusedDropout(p=0.5) @@ -61,15 +62,15 @@ class FusedDropout(paddle.nn.Layer): >>> y_train = m(x) >>> print(y_train) Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, - [[2., 0., 6.], - [0., 0., 0.]]) + [[0., 0., 6.], + [0., 0., 0.]]) >>> m.eval() # switch the model to test phase >>> y_test = m(x) >>> print(y_test) Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, - [[1., 2., 3.], - [4., 5., 6.]]) + [[1., 2., 3.], + [4., 5., 6.]]) """ def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None): diff --git a/python/paddle/incubate/optimizer/pipeline.py b/python/paddle/incubate/optimizer/pipeline.py index 6c0e80b1f57104..b7ae315576d791 100644 --- a/python/paddle/incubate/optimizer/pipeline.py +++ b/python/paddle/incubate/optimizer/pipeline.py @@ -48,47 +48,47 @@ class PipelineOptimizer: Examples: .. code-block:: python - import paddle - import paddle.base as base - import paddle.base.layers as layers - import numpy as np - - paddle.enable_static() - with base.device_guard("gpu:0"): - x = paddle.static.data(name='x', shape=[-1, 1], dtype='int64', lod_level=0) - y = paddle.static.data(name='y', shape=[-1, 1], dtype='int64', lod_level=0) - data_loader = base.io.DataLoader.from_generator( - feed_list=[x, y], - capacity=64, - use_double_buffer=True, - iterable=False) - - emb_x = layers.embedding(input=x, param_attr=base.ParamAttr(name="embx"), size=[10,2], is_sparse=False) - emb_y = layers.embedding(input=y, param_attr=base.ParamAttr(name="emby",learning_rate=0.9), size=[10,2], is_sparse=False) - - with base.device_guard("gpu:1"): - concat = layers.concat([emb_x, emb_y], axis=1) - fc = paddle.static.nn.fc(x=concat, name="fc", size=1, num_flatten_dims=1, bias_attr=False) - loss = paddle.mean(fc) - optimizer = paddle.optimizer.SGD(learning_rate=0.5) - optimizer = paddle.incubate.optimizer.PipelineOptimizer(optimizer) - optimizer.minimize(loss) - - def train_reader(): - for _ in range(4): - x = np.random.random(size=[1]).astype('int64') - y = np.random.random(size=[1]).astype('int64') - yield x, y - data_loader.set_sample_generator(train_reader, batch_size=1) - - place = base.CUDAPlace(0) - exe = base.Executor(place) - exe.run(base.default_startup_program()) - batch_size = 1 - data_loader.start() - exe.train_from_dataset( - base.default_main_program()) - data_loader.reset() + >>> import paddle + >>> import paddle.base as base + >>> import paddle.base.layers as layers + >>> import numpy as np + + >>> paddle.enable_static() + >>> with base.device_guard("gpu:0"): + ... x = paddle.static.data(name='x', shape=[-1, 1], dtype='int64', lod_level=0) + ... y = paddle.static.data(name='y', shape=[-1, 1], dtype='int64', lod_level=0) + ... data_loader = base.io.DataLoader.from_generator( + ... feed_list=[x, y], + ... capacity=64, + ... use_double_buffer=True, + ... iterable=False) + + ... emb_x = layers.embedding(input=x, param_attr=base.ParamAttr(name="embx"), size=[10,2], is_sparse=False) + ... emb_y = layers.embedding(input=y, param_attr=base.ParamAttr(name="emby",learning_rate=0.9), size=[10,2], is_sparse=False) + + >>> with base.device_guard("gpu:1"): + ... concat = layers.concat([emb_x, emb_y], axis=1) + ... fc = paddle.static.nn.fc(x=concat, name="fc", size=1, num_flatten_dims=1, bias_attr=False) + ... loss = paddle.mean(fc) + >>> optimizer = paddle.optimizer.SGD(learning_rate=0.5) + >>> optimizer = paddle.incubate.optimizer.PipelineOptimizer(optimizer) + >>> optimizer.minimize(loss) + + >>> def train_reader(): + ... for _ in range(4): + ... x = np.random.random(size=[1]).astype('int64') + ... y = np.random.random(size=[1]).astype('int64') + ... yield x, y + >>> data_loader.set_sample_generator(train_reader, batch_size=1) + + >>> place = paddle.CUDAPlace(0) + >>> exe = paddle.static.Executor(place) + >>> exe.run(paddle.static.default_startup_program()) + >>> batch_size = 1 + >>> data_loader.start() + >>> exe.train_from_dataset( + ... paddle.static.default_main_program()) + >>> data_loader.reset() """ def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0): diff --git a/python/paddle/incubate/optimizer/recompute.py b/python/paddle/incubate/optimizer/recompute.py index 9cbd8894f18897..2545115fa0d015 100644 --- a/python/paddle/incubate/optimizer/recompute.py +++ b/python/paddle/incubate/optimizer/recompute.py @@ -49,45 +49,57 @@ class RecomputeOptimizer(Optimizer): Examples: .. code-block:: python - import paddle - import paddle.base as base - import numpy as np - - paddle.enable_static() - - def gen_data(): - return {"x": np.random.random(size=(32, 32)).astype('float32'), - "y": np.random.randint(2, size=(32, 1)).astype('int64')} - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - print(input_x) - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - sgd.minimize(cost) - - print("Finished optimize") - place = base.CPUPlace() - exe = base.Executor(place) - exe.run(base.default_startup_program()) - step = 10 - - for i in range(step): - cost_val = exe.run(feed=gen_data(), - program=base.default_main_program(), - fetch_list=[cost.name]) - print("step=%d cost=%f" % (i, cost_val[0])) + >>> import paddle + >>> import numpy as np + + >>> paddle.enable_static() + + >>> def gen_data(): + ... return {"x": np.random.random(size=(32, 32)).astype('float32'), + ... "y": np.random.randint(2, size=(32, 1)).astype('int64')} + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... print(input_x) + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> sgd.minimize(cost) + + >>> print("Finished optimize") + Finished optimize + >>> place = paddle.CPUPlace() + >>> exe = paddle.static.Executor(place) + >>> exe.run(paddle.static.default_startup_program()) + >>> step = 10 + + >>> for i in range(step): + ... cost_val = exe.run(feed=gen_data(), + ... program=paddle.static.default_main_program(), + ... fetch_list=[cost.name]) + ... print("step=%d cost=%f" % (i, cost_val[0])) + var x : LOD_TENSOR.shape(-1, 32).dtype(float32).stop_gradient(True) + Finished optimize + step=0 cost=0.737203 + step=1 cost=1.308077 + step=2 cost=0.768422 + step=3 cost=1.239475 + step=4 cost=0.882643 + step=5 cost=0.738027 + step=6 cost=0.819374 + step=7 cost=0.818534 + step=8 cost=0.753692 + step=9 cost=0.787448 """ @@ -132,33 +144,34 @@ def load(self, state_dict): Examples: .. code-block:: python - import paddle - import paddle.base as base - - paddle.enable_static() - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - try: - state_dict = {} - sgd.load(state_dict) - except NotImplementedError as e: - print(e) + >>> import paddle + + >>> paddle.enable_static() + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> try: + ... state_dict = {} + ... sgd.load(state_dict) + >>> except NotImplementedError as e: + ... print(e) + load function is not supported by Recompute Optimizer for now """ raise NotImplementedError( "load function is not supported by Recompute Optimizer for now" @@ -177,42 +190,42 @@ def apply_gradients(self, params_grads): Examples: .. code-block:: python - import paddle - import paddle.base as base - import paddle.base.framework as framework - - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - - program = cost.block.program - with framework.program_guard(program, None): - optimize_ops = sgd.apply_gradients(params_grads) - - print("Finished apply gradients") + >>> import paddle + >>> import paddle.base.framework as framework + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + + >>> program = cost.block.program + >>> with framework.program_guard(program, None): + ... optimize_ops = sgd.apply_gradients(params_grads) + + >>> print("Finished apply gradients") + Finished apply gradients """ return self._optimizer.apply_gradients(params_grads=params_grads) @@ -651,36 +664,36 @@ def backward( Examples: .. code-block:: python - import paddle - import paddle.base as base - - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - print("Finished backward") + >>> import paddle + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + >>> print("Finished backward") + Finished backward """ assert ( self._checkpoints is not None @@ -733,39 +746,41 @@ def apply_optimize(self, loss, startup_program, params_grads): params_grads (list): list of (param, grad) pair to do optimization. Examples: .. code-block:: python - import paddle - import paddle.base as base - paddle.enable_static() - - def mlp(input_x, input_y, hid_dim=128, label_dim=2): - fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) - prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') - cost = paddle.nn.functional.cross_entropy( - input=prediction, label=input_y, - reduction='none', use_softmax=False - ) - sum_cost = paddle.mean(cost) - return sum_cost, fc_1, prediction - - input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') - input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') - cost, fc_1, pred = mlp(input_x, input_y) - print("Finished FF") - - sgd = paddle.optimizer.Adam(learning_rate=0.01) - sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) - sgd._set_checkpoints([fc_1, pred]) - params_grads = sgd.backward( - cost, - startup_program=None, - parameter_list=None, - no_grad_set=None) - - optimize_ops = sgd.apply_optimize( - cost, startup_program=None, params_grads=params_grads) - - print("Finished apply_optimize") + >>> import paddle + + >>> paddle.enable_static() + + >>> def mlp(input_x, input_y, hid_dim=128, label_dim=2): + ... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim) + ... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax') + ... cost = paddle.nn.functional.cross_entropy( + ... input=prediction, label=input_y, + ... reduction='none', use_softmax=False + ... ) + ... sum_cost = paddle.mean(cost) + ... return sum_cost, fc_1, prediction + + >>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32') + >>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64') + >>> cost, fc_1, pred = mlp(input_x, input_y) + >>> print("Finished FF") + Finished FF + + >>> sgd = paddle.optimizer.Adam(learning_rate=0.01) + >>> sgd = paddle.incubate.optimizer.RecomputeOptimizer(sgd) + >>> sgd._set_checkpoints([fc_1, pred]) + >>> params_grads = sgd.backward( + ... cost, + ... startup_program=None, + ... parameter_list=None, + ... no_grad_set=None) + + >>> optimize_ops = sgd.apply_optimize( + ... cost, startup_program=None, params_grads=params_grads) + + >>> print("Finished apply_optimize") + Finished apply_optimize """ func = ( diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 642b4c8b9529e8..65da105499b205 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -692,13 +692,15 @@ class SymbolicStaticFunction(StaticFunction): def __init__(self, function, input_spec=None, **kwargs): if input_spec is not None: warnings.warn( - "\nSymbolic Trace don't support input_spec arguments. It will Will not produce any effect.\n" + "\nSymbolic Trace don't support input_spec arguments. It will not produce any effect.\n" "1. You can disable fallback mode by `paddle.jit.to_static(enable_fallback=False)` to switch to AST to static, then you can assign input spec.\n" ) super().__init__(function, input_spec, **kwargs) self.last_call_input_spec = None def _perform_call(self, *args, **kwargs): + from ..sot import symbolic_translate + args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) ( input_args_with_spec, @@ -706,16 +708,6 @@ def _perform_call(self, *args, **kwargs): ) = self._function_spec.args_to_input_spec(args, kwargs) self.last_call_input_spec = input_args_with_spec - try: - from sot import symbolic_translate - except: - import os - - os.system( - "pip install git+https://github.com/PaddlePaddle/PaddleSOT@develop" - ) - from sot import symbolic_translate - build_strategy = self._kwargs.get("build_strategy", None) backend = self._kwargs.get("backend", None) traced_fun = symbolic_translate( diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index fd5eba66c76842..1eab7edc738bfd 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -178,6 +178,21 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ) +def create_undefined_variable_local(): + helper = LayerHelper('create_undefined_variable', **locals()) + var = helper.create_variable( + name=unique_name.generate("undefined_var"), + shape=[1], + dtype="float64", + type=core.VarDesc.VarType.LOD_TENSOR, + stop_gradient=False, + is_data=True, + need_check_feed=False, + ) + paddle.assign(RETURN_NO_VALUE_MAGIC_NUM, var) + return var + + def create_undefined_variable(): var = data_layer_not_check( unique_name.generate("undefined_var"), [1], "float64" diff --git a/python/paddle/jit/dy2static/utils_helper.py b/python/paddle/jit/dy2static/utils_helper.py index 601e3241d7464c..b54c026745f700 100644 --- a/python/paddle/jit/dy2static/utils_helper.py +++ b/python/paddle/jit/dy2static/utils_helper.py @@ -183,3 +183,14 @@ def type_from_annotation(annotation): # raise warning if not found warn("Currently we don't support annotation: %s" % annotation_str) return NodeVarType.UNKNOWN + + +def set_dynamic_shape(variable, shape_list): + if paddle.base.dygraph.base.in_to_static_mode(): + assert isinstance( + variable, paddle.base.framework.Variable + ), "In to_static mode, variable must be a Variable." + variable.desc.set_shape(shape_list) + else: + # in dygraph mode, dynamic shape is not needed, just do nothing. + return diff --git a/python/paddle/jit/sot/__init__.py b/python/paddle/jit/sot/__init__.py new file mode 100644 index 00000000000000..1b45c0c55389b2 --- /dev/null +++ b/python/paddle/jit/sot/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import psdb # noqa: F401 +from .opcode_translator.breakpoint import ( # noqa: F401 + BM, + add_breakpoint, + add_event, +) +from .opcode_translator.skip_files import skip_function # noqa: F401 +from .translate import symbolic_translate # noqa: F401 diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py new file mode 100644 index 00000000000000..8ea3ec28f19a4b --- /dev/null +++ b/python/paddle/jit/sot/infer_meta.py @@ -0,0 +1,282 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.amp.auto_cast import amp_state +from paddle.base.unique_name import UniqueNameGenerator +from paddle.base.unique_name import guard as UniqueNameGuard +from paddle.static import Program +from paddle.utils import flatten, is_sequence + +from .utils import Cache, Singleton, map_if_extend, meta_str + + +class MetaInfo: + def __init__( + self, shape, dtype, stop_gradient, name, persistable, type, place + ): + self.name = name + self.persistable = persistable + self.type = type + self.place = place + self.shape = shape + self.dtype = dtype + self.stop_gradient = stop_gradient + + @staticmethod + def from_tensor(tensor): + # We always use float32 in simulation if AMP is enabled. + dtype = tensor.dtype + current_amp_state = amp_state() + if ( + dtype == paddle.float16 + and current_amp_state is not None + and current_amp_state["dtype"] == "float16" + ): + dtype = paddle.float32 + return MetaInfo( + list(tensor.shape), + dtype, + tensor.stop_gradient, + tensor.name, + tensor.persistable, + tensor.type, + tensor.place, + ) + + def is_dynamic_shape(self): + """ + if -1 in shape, return True + else: return False + """ + return -1 in self.shape + + def to_input_spec(self): + return paddle.static.InputSpec( + self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient + ) + + def guard_str(self): + return f"({self.shape}, {self.dtype}, {self.stop_gradient})" + + def __repr__(self): + return meta_str(self.shape, self.dtype, self.stop_gradient) + + def __eq__(self, meta): + return ( + self.shape == meta.shape + and self.dtype == meta.dtype + and self.stop_gradient == meta.stop_gradient + ) + + def __hash__(self): + return hash((tuple(self.shape), self.dtype, self.stop_gradient)) + + +@Singleton +class VariableCreator: + """ + We use the static graph Variable to infer the meta information of Tensor. + This singleton class is used to create Variable for infer meta. + """ + + def __init__(self): + self.var_cache = {} + self.main_program = Program() + self.startup_program = Program() + self.var_name_generator = UniqueNameGenerator("infer_meta_variable_") + + def gen_name(self, meta): + name = f"{meta.dtype}_{meta.stop_gradient}" + for l in meta.shape: + name += f"_{l}" + return name + + def create_var(self, meta): + var = self.main_program.global_block().create_var( + shape=meta.shape, + dtype=meta.dtype, + stop_gradient=meta.stop_gradient, + ) + assert not isinstance( + var, paddle.Tensor + ), "Expect a Variable, but got a Tensor." + return var + + def get_variable(self, meta): + var_feature_name = self.gen_name(meta) + if var_feature_name not in self.var_cache: + self.var_cache[var_feature_name] = self.create_var(meta) + return self.var_cache[var_feature_name] + + def infer_meta(self, func, *args, **kwargs): + with paddle.base.framework._dygraph_guard(None), UniqueNameGuard( + self.var_name_generator + ): + args, kwargs = convert_meta_to_variable( + args + ), convert_meta_to_variable(kwargs) + + with paddle.static.program_guard( + self.main_program, self.startup_program + ): + if isinstance(func, str): + # TODO(Aurelius84): Is length of args always greater than 0? + # Do we need add condition check here? + out = getattr(args[0], func)(*args[1:], **kwargs) + else: + out = func(*args, **kwargs) + + return convert_variable_to_meta_info(out) + + +def convert_meta_to_variable(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: VariableCreator().get_variable(x), + false_fn=lambda x: x, + ) + + +def convert_meta_to_input_spec(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: x.to_input_spec(), + # TODO(xiongkun): can x be tensor ? + false_fn=lambda x: paddle.static.InputSpec.from_tensor(x) + if isinstance(x, paddle.Tensor) + else x, + ) + + +def convert_variable_to_meta_info(args): + return map_if_extend( + args, + pred=lambda x: isinstance(x, paddle.static.Variable), + true_fn=lambda x: MetaInfo.from_tensor(x), + false_fn=lambda x: x, + ) + + +def infer_meta(func, *args, **kwargs): + fn = SpecialInferMeta().get_infermeta_fn(func) + if fn: + return fn(*args, **kwargs) + return VariableCreator().infer_meta(func, *args, **kwargs) + + +def infer_meta_for_layer(layer, *args, **kwargs): + assert isinstance( + layer, paddle.nn.Layer + ), f"Expect a Layer, but got {layer}." + layer = paddle.jit.to_static(layer, enable_fallback=False) + + args_, kwargs_ = convert_meta_to_input_spec((args, kwargs)) + + ( + concrete_program, + partial_program_layer, + ) = layer.forward.get_concrete_program(*args_, **kwargs_) + + out = partial_program_layer._restore_out( + paddle.utils.flatten( + convert_variable_to_meta_info(concrete_program.outputs) + ) + ) + layer.forward.rollback() + return out + + +@Singleton +class SpecialInferMeta: + """ + There are some functions that cannot be inferred directly through static graph, + and need to be implemented manually. This class is used to implement infer meta + for these functions. + """ + + def __init__(self): + pass + + def get_infermeta_fn(self, fn): + try: + funcname = fn.__name__ + return getattr(self, f"infermeta_{funcname}") + except: + pass + return None + + def infermeta_grad( + self, + outputs, + inputs, + grad_outputs=None, + retain_graph=None, + create_graph=False, + only_inputs=True, + allow_unused=False, + no_grad_vars=None, + ): + if not is_sequence(inputs): + inputs = [inputs] + return inputs + + +@Singleton +class InferMetaCache(Cache): + def key_fn( + self, func, *args, **kwargs + ): # args & kwargs have transformed to MetaInfo + try: + retval = hash( + ( + func, + tuple(flatten(args)), + tuple(kwargs.keys()), + tuple(flatten(kwargs)), + ) + ) + except Exception as e: + return None + return retval + + def value_fn(self, func, *args, **kwargs): + return infer_meta(func, *args, **kwargs) + + +@Singleton +class LayerInferMetaCache(Cache): + def key_fn(self, layer, *args, **kwargs): + params = [ + MetaInfo.from_tensor(x) + for x in layer.parameters(include_sublayers=True) + ] + try: + retval = hash( + ( + layer, + tuple(params), + tuple(flatten(args)), + tuple(kwargs.keys()), + tuple(flatten(kwargs)), + ) + ) + except Exception as e: + return None + return retval + + def value_fn(self, layer, *args, **kwargs): + return infer_meta_for_layer(layer, *args, **kwargs) diff --git a/python/paddle/jit/sot/opcode_translator/__init__.py b/python/paddle/jit/sot/opcode_translator/__init__.py new file mode 100644 index 00000000000000..bf230190e3e112 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transform import eval_frame_callback # noqa: F401 diff --git a/python/paddle/jit/sot/opcode_translator/breakpoint.py b/python/paddle/jit/sot/opcode_translator/breakpoint.py new file mode 100644 index 00000000000000..6f3217dd8776ea --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/breakpoint.py @@ -0,0 +1,179 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import traceback +from dataclasses import dataclass + +from ..opcode_translator.instruction_utils import instrs_info +from ..utils import Singleton, log +from .executor.opcode_executor import OpcodeExecutorBase + +# this file is a debug utils files for quick debug +# >>> sot.add_breakpoint(file, line) +# >>> sot.remove_breakpoint(file, line) + + +@dataclass +class Breakpoint: + file: str + line: int + co_name: str + offset: int + + def __hash__(self): + return hash((self.file, self.line, self.co_name, self.offset)) + + +@Singleton +class BreakpointManager: + def __init__(self): + self.breakpoints = set() + self.executors = OpcodeExecutorBase.call_stack + self.activate = 0 + self.record_event = [] + + def clear_event(self, event): + self.record_event.clear() + + def add_event(self, event): + """ + event in ['All' ,'FallbackError', 'BreakGraphError', 'InnerError'] + """ + self.record_event.append(event) + + def add(self, file, line, coname=None, offset=None): + log(1, f"add breakpoint at {file}:{line}\n") + self.breakpoints.add(Breakpoint(file, line, coname, offset)) + + def addn(self, *lines): + """ + called inside a executor. add a list of line number in current file. + """ + if not isinstance(lines, (list, tuple)): + lines = [lines] + for line in lines: + file = self.cur_exe._code.co_filename + self.add(file, line) + + def clear(self): + self.breakpoints.clear() + + def hit(self, file, line, co_name, offset): + if Breakpoint(file, line, None, None) in self.breakpoints: + return True + if Breakpoint(file, line, co_name, offset) in self.breakpoints: + return True + return False + + def locate(self, exe): + for i, _e in enumerate(self.executors): + if _e is exe: + self.activate = i + return + raise RuntimeError("Not found executor.") + + def up(self): + if self.activate == 0: + return + self.activate -= 1 + print("current function is: ", self.cur_exe._code.co_name) + + def down(self): + if self.activate >= len(self.executors) - 1: + return + self.activate += 1 + print("current function is: ", self.cur_exe._code.co_name) + + def opcode(self, cur_exe=None): + if cur_exe is None: + cur_exe = self.cur_exe + instr = cur_exe._instructions[cur_exe._lasti - 1] + message = f"[Translate {cur_exe}]: (line {cur_exe._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {cur_exe._stack}\n" + return message + + def bt(self): + """ + display all inline calls: backtrace. + """ + for exe in self.executors: + lines, _ = inspect.getsourcelines(exe._code) + print( + " " + + exe._code.co_filename + + f"({exe._current_line})" + + f"{exe._code.co_name}()" + ) + print(f"-> {lines[0].strip()}") + print(f"-> {self.opcode(exe)}") + pass + + def on_event(self, event): + if "All" in self.record_event or event in self.record_event: + print("event captured.") + self.activate = len(self.executors) - 1 + breakpoint() + + def _dis_source_code(self): + cur_exe = self.executors[self.activate] + lines, start_line = inspect.getsourcelines(cur_exe._code) + cur_line = cur_exe._current_line + lines[ + cur_line - start_line + 1 : cur_line - start_line + 1 + ] = " ^^^^^ HERE \n" + print("\033[31mSource Code is: \033[0m") + print("".join(lines)) + + def dis(self, range=5): + """ + display all instruction code and source code. + """ + print("displaying debug info...") + cur_exe = self.cur_exe + print(self._dis_source_code()) + + print(f"\n{cur_exe._code}") + lasti = cur_exe._lasti + lines = instrs_info(cur_exe._instructions, lasti - 1, range) + print("\n".join(lines)) + + @property + def cur_exe(self): + exe = self.executors[self.activate] + return exe + + def sir(self): + """ + display sir in a page. + """ + print("displaying sir...") + self.cur_exe.print_sir() + + def pe(self, e): + """ + print exception. + """ + lines = traceback.format_tb(e.__traceback__) + print("".join(lines)) + + +def add_breakpoint(file, line, co_name=None, offset=None): + BM.add(file, line, co_name, offset) + + +def add_event(event): + BM.add_event(event) + + +BM = BreakpointManager() diff --git a/python/paddle/jit/sot/opcode_translator/custom_code.py b/python/paddle/jit/sot/opcode_translator/custom_code.py new file mode 100644 index 00000000000000..da674fb673170a --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/custom_code.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import types +from typing import NamedTuple + + +class CustomCode(NamedTuple): + code: types.CodeType | None + disable_eval_frame: bool diff --git a/python/paddle/jit/sot/opcode_translator/executor/__init__.py b/python/paddle/jit/sot/opcode_translator/executor/__init__.py new file mode 100644 index 00000000000000..4d9db28d227077 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import variable_dispatch # noqa: F401 diff --git a/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py b/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py new file mode 100644 index 00000000000000..9b00dcde0462b4 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file stores the customed function that will be called by the dispatch mechanism. + +from ...utils import BreakGraphError, FallbackError + + +def raise_break_graph_fn(*args, **kwarg): + raise BreakGraphError("raise by raise_break_graph_fn.") + + +def raise_not_implement_fn(*args, **kwarg): + raise FallbackError("raise by raise_break_graph_fn.") + + +# just a function for operator.in +def operator_in(left, right): + return left in right + + +def operator_not_in(left, right): + return left not in right + + +def operator_exception_match(left, right): + pass + + +def operator_BAD(left, right): + pass + + +def operator_is_none(val): + pass + + +def operator_is_not_none(val): + pass + + +def tensor_numel(x): + pass diff --git a/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py b/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py new file mode 100644 index 00000000000000..315066f27e820c --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/dispatcher.py @@ -0,0 +1,294 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import inspect +import operator +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, TypeVar + +from ...utils import InnerError, NameGenerator, hashable + +if TYPE_CHECKING: + T = TypeVar("T") + Args = Tuple[T, ...] + Kwargs = Dict[str, T] + + +def format_type(type_: type[Any] | tuple[type[Any], ...]) -> str: + if not isinstance(type_, tuple): + type_ = (type_,) + return " | ".join([t.__name__ for t in type_]) + + +def format_param(param: Parameter) -> str: + kind = param.kind + # TODO: support VAR_KEYWORD + if kind == inspect.Parameter.VAR_POSITIONAL: + return f"*{format_type(param.type)}" + else: + return format_type(param.type) + + +def convert_annotation_to_type(type_str: str) -> tuple[type[Any], ...]: + """ + Convert type annotation to runtime value. Because we are using :pep:`563` + to use the future annotation syntax, we cannot use `get_type_hints <https://docs.python.org/3.8/library/typing.html#typing.get_type_hints>`_ + directly. Currently, only the builtins and variables namespaces are supported. + + Returns: + tuple: The converted type. + """ + + import builtins + + from . import variables + + type_str = type_str.strip() + if type_str == "Any": + type_str = "object" + + if "|" in type_str: + return reduce( + operator.add, map(convert_annotation_to_type, type_str.split("|")) + ) + + search_namespaces = [variables, builtins] + for namespace in search_namespaces: + if hasattr(namespace, type_str): + return (getattr(namespace, type_str),) + raise InnerError(f"Cannot find type {type_str} in {search_namespaces}") + + +class Parameter: + name_gen = NameGenerator("param_") + annotation: str + name: str + + def __init__( + self, + annotation: str, + *, + kind: inspect._ParameterKind = inspect.Parameter.POSITIONAL_OR_KEYWORD, + name: str | None = None, + default: Any = inspect._empty, + ): + self.name = name if name is not None else Parameter.name_gen.next() + self.annotation = annotation + self.kind = kind + self.default = default + + def to_parameter(self) -> inspect.Parameter: + return inspect.Parameter( + self.name, + kind=self.kind, + annotation=self.annotation, + default=copy.copy(self.default), + ) + + @cached_property + def type(self) -> tuple[type[Any], ...]: + return convert_annotation_to_type(self.annotation) + + def match_arg(self, arg: Any) -> bool: + # TODO: support VAR_KEYWORD + if self.kind == inspect.Parameter.VAR_POSITIONAL: + is_tuple = isinstance(arg, tuple) + return is_tuple and all(isinstance(a, self.type) for a in arg) + else: + return isinstance(arg, self.type) + + @staticmethod + def from_str(annotation: str) -> Parameter: + return Parameter(annotation) + + @staticmethod + def from_parameter(parameter: inspect.Parameter) -> Parameter: + if parameter.annotation != parameter.empty and not isinstance( + parameter.annotation, str + ): + raise InnerError( + f"Parameter {parameter} has annotation {parameter.annotation} " + "which is not a string. Please add `from __future__ import annotations` " + "to the top of your file." + ) + annotation = ( + parameter.annotation + if parameter.annotation != parameter.empty + else "Any" + ) + + return Parameter( + annotation, + kind=parameter.kind, + name=parameter.name, + default=parameter.default, + ) + + def __repr__(self) -> str: + default_repr = f"= {self.default!r}" + return f"Parameter({', '.join([self.annotation, default_repr])})" + + +def optional(annotation: str, default: Any = None) -> Parameter: + return Parameter(annotation, default=default) + + +class Pattern: + parameters: dict[str, Parameter] + signature: inspect.Signature + + def __init__( + self, + *parameters: Parameter, + ): + self.parameters = { + parameter.name: parameter for parameter in parameters + } + self.signature = inspect.Signature( + [parameter.to_parameter() for parameter in self.parameters.values()] + ) + + def match_inputs(self, /, *args: Any, **kwargs: Any) -> bool: + """ + Match the input parameters of the function. + + Returns: + bool: Whether the input parameters match the pattern. + """ + try: + bound_args = self.signature.bind(*args, **kwargs) + except TypeError: + return False + for arg_name, arg_value in bound_args.arguments.items(): + if arg_name not in self.parameters: + continue + if not self.parameters[arg_name].match_arg(arg_value): + return False + return True + + def __repr__(self) -> str: + types_repr = ", ".join( + [format_param(param) for param in self.parameters.values()] + ) + return f"Pattern({types_repr})" + + +class Dispatcher: + """ + Used for pattern registration and distribution. + + For more design ideas, refer to the `Builtin dispatcher <https://github.com/PaddlePaddle/PaddleSOT/blob/develop/docs/design/builtin-dispatcher.md>`_ for details. + + Examples: + + >>> def builtin_add(a: int, b: int) -> int: + ... ... + ... + >>> Dispatcher.register(builtin_add, ("int", "int"), lambda a, b: a + b) + >>> handler = Dispatcher.dispatch(builtin_add, 1, 2) + >>> handler(1, 2) + 3 + """ + + handlers: dict[ + Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]] + ] = {} + graph: Any = None + + @classmethod + def register( + cls, + fn: Callable[..., Any], + parameters: tuple[str | Parameter, ...], + handler: Callable[..., Any], + ): + """ + Registering function signature. + + Args: + fn: The function to be registered. + parameters: The parameters of the function to be registered. + handler: The handler function. + """ + _parameters = tuple( + Parameter.from_str(parameter) + if isinstance(parameter, str) + else parameter + for parameter in parameters + ) + if fn not in cls.handlers: + cls.handlers[fn] = [] + cls.handlers[fn].append((Pattern(*_parameters), handler)) + + @classmethod + def register_decorator(cls, fn: Callable[..., Any]): + """ + Decorator mode of register, Used to register some complex functions. + + Args: + fn: The function to be registered. + + Examples: + >>> def builtin_add(a: int, b: int) -> int: + ... ... + ... + >>> @Dispatcher.register_decorator(builtin_add) + ... def builtin_add_dispatcher(a: int, b: int) -> int: + ... return a + b + ... + >>> handler = Dispatcher.dispatch(builtin_add, 1, 2) + >>> handler(1, 2) + 3 + """ + + def decorator(handler: Callable[..., Any]): + signature = inspect.signature(handler) + parameters = tuple( + Parameter.from_parameter(parameter) + for parameter in signature.parameters.values() + ) + cls.register(fn, parameters, handler) + + return decorator + + @classmethod + def call(cls, fn, *args, **kwargs): + func = cls.dispatch(fn, *args, **kwargs) + if func is None: + raise InnerError( + f"Cannot find handler for {fn} with args {args} and kwargs {kwargs}" + ) + return func(*args, **kwargs) + + @classmethod + def dispatch( + cls, fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Callable[..., Any] | None: + """ + Find the matching handler from the registered functions. + + Args: + fn: The function to be dispatched. + args: The args of the function. + kwargs: The kwargs of the function. + """ + if not hashable(fn) or fn not in cls.handlers: + return None + for pattern, handler in cls.handlers[fn]: + if pattern.match_inputs(*args, **kwargs): + return handler + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py new file mode 100644 index 00000000000000..67d656f4dcd752 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -0,0 +1,230 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import traceback +import types +from typing import List, Tuple + +from ...profiler import EventGuard, event_register +from ...psdb import NO_FALLBACK_CODES +from ...utils import ( + BreakGraphError, + FallbackError, + InnerError, + Singleton, + is_strict_mode, + log, + log_do, +) +from ..custom_code import CustomCode +from .guard import Guard +from .opcode_executor import OpcodeExecutor, OpcodeExecutorBase +from .pycode_generator import PyCodeGen + +GuardedFunction = Tuple[CustomCode, Guard] +GuardedFunctions = List[GuardedFunction] + +dummy_guard: Guard = lambda frame: True +dummy_guard.expr = "lambda frame: True" +dummy_guard.lambda_expr = "lambda frame: True" + + +@Singleton +class OpcodeExecutorCache: + """ + A singleton class that implements a cache for translated instructions. + This cache is used to store previously translated instructions along with their corresponding guard functions. + + Attributes: + cache (dict): A dictionary that maps code objects to tuples of a cache getter function and a list of guarded functions. + translate_count (int): The count of how many instructions have been translated. It is used to test whether the cache hits. + """ + + MAX_CACHE_SIZE = 20 + cache: dict[types.CodeType, GuardedFunctions] + translate_count: int + + def __init__(self): + self.cache = {} + self.translate_count = 0 + + def clear(self): + """ + Clears the cache and resets the translate count. + """ + self.cache.clear() + self.translate_count = 0 + + def __call__(self, frame: types.FrameType, **kwargs) -> CustomCode: + code: types.CodeType = frame.f_code + if code not in self.cache: + log(2, f"[Cache]: Firstly call {code}\n") + new_custom_code, guard_fn = self.translate(frame, **kwargs) + self.cache[code] = [(new_custom_code, guard_fn)] + return new_custom_code + guarded_fns = self.cache[code] + return self.lookup(frame, guarded_fns, **kwargs) + + @event_register("lookup") + def lookup( + self, frame: types.FrameType, guarded_fns: GuardedFunctions, **kwargs + ) -> CustomCode: + """ + Looks up the cache for a matching code object and returns a custom code object if a matching guard function is found, otherwise None. + + Args: + frame (types.FrameType): The frame whose code object needs to be looked up in the cache. + guarded_fns (GuardedFunctions): The list of guarded functions associated with the code object. + + Returns: + CustomCode | None: The custom code object if a matching guard function is found, otherwise None. + """ + + if len(guarded_fns) >= self.MAX_CACHE_SIZE: + log(2, "[Cache]: Exceed max cache size, skip it\n") + return CustomCode(None, False) + + for custom_code, guard_fn in guarded_fns: + try: + with EventGuard("try guard"): + guard_result = guard_fn(frame) + if guard_result: + log( + 2, + f"[Cache]: Cache hit, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", + ) + return custom_code + else: + log_do( + 4, + self.analyse_guard_global_object(guard_fn), + ) + log( + 2, + f"[Cache]: Cache miss, Guard is \n{getattr(guard_fn, 'expr', 'None')}\n", + ) + log_do( + 2, + self.analyse_guard_error(guard_fn, frame), + ) + except Exception as e: + log(2, f"[Cache]: Guard function error: {e}\n") + continue + + log(2, "[Cache]: all guards missed\n") + new_custom_code, guard_fn = self.translate(frame, **kwargs) + guarded_fns.append((new_custom_code, guard_fn)) + return new_custom_code + + def translate( + self, frame: types.FrameType, **kwargs + ) -> tuple[CustomCode, Guard]: + """ + Translates the given frame's code object and returns the cache getter function and a guarded function for the translated code object. + + Args: + frame (types.FrameType): The frame whose code object needs to be translated. + + Returns: + tuple[CustomCode, Guard]: The cache getter function and a guarded function for the translated code object. + """ + code: types.CodeType = frame.f_code + self.translate_count += 1 + custom_new_code, guard_fn = start_translate(frame, **kwargs) + return custom_new_code, guard_fn + + def analyse_guard_global_object(self, guard_fn): + def inner(): + for key in guard_fn.__globals__.keys(): + if key.startswith("__object"): + print( + f"[Cache] meet global object: {key} : {guard_fn.__globals__[key]}", + ) + + return inner + + def analyse_guard_error(self, guard_fn, frame): + def inner(): + guard_expr = guard_fn.lambda_expr + lambda_head = "lambda frame: " + guard_expr = guard_expr.replace(lambda_head, "") + guards = guard_expr.split(" and ") + for guard_str in guards: + guard = eval(lambda_head + guard_str, guard_fn.__globals__) + result = False + try: + result = guard(frame) + except Exception as e: + print( + f"[Cache]: skip checking {guard_str}\n because error occured {e}" + ) + if result is False: + print(f"[Cache]: missed at {guard_str}") + return + print("[Cache]: missed guard not found.") + + return inner + + +def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: + """ + Starts the translation process for the given frame and returns the translated code object and its guard function, or None if translation fails. + + Args: + frame: The frame to be translated. + + Returns: + GuardedFunction | None: The translated code object and its guard function, or None if translation fails. + """ + simulator = OpcodeExecutor(frame, **kwargs) + try: + new_custom_code, guard_fn = simulator.transform() + return new_custom_code, guard_fn + # TODO(zrr1999): InnerError maybe place before (FallbackError, BreakGraphError) + # TODO(0x45f): handle BreakGraphError to trigger fallback + except BreakGraphError as e: + raise RuntimeError( + f"Found BreakGraphError raised, it should not be catch at start_translate!\n{e}" + ) + except FallbackError as e: + if simulator._code in NO_FALLBACK_CODES: + raise InnerError( + f"{simulator._code.co_name} should not fallback, but got '{e}'" + ) + # if disable_eval_frame is True, it means we want fallback to speedup rather than error occured + if is_strict_mode() and e.disable_eval_frame is False: + raise + log( + 2, + f"Unsupport Frame is {frame.f_code}, error message is: \n" + + "".join(traceback.format_exception(type(e), e, e.__traceback__)), + ) + + # NOTE: If resume fn need fallback, we should replace NullVariable using NULL otherwise will fail to run + py_codegen = PyCodeGen(frame) + new_code = py_codegen.replace_null_variable() + # simulation not complete, not sure whether this code has sir, set disable_eval_frame = False + guard_fn = ( + dummy_guard if e.disable_eval_frame is False else simulator.guard_fn + ) + return ( + CustomCode(new_code, e.disable_eval_frame), + guard_fn, + ) + except Exception as e: + raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e + finally: + simulator.cleanup() diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py new file mode 100644 index 00000000000000..0859ecfec46b9a --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -0,0 +1,684 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is specifically used to handle the problem +# of generating a Graph from a linear function call. + +from __future__ import annotations + +import builtins +import inspect +from collections import namedtuple +from copy import deepcopy +from functools import cached_property +from typing import Any, Callable + +from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo +from ...profiler import EventGuard, event_register +from ...symbolic.statement_ir import Symbol +from ...symbolic.symbolic_context import SymbolicTraceContext +from ...utils import ( + NameGenerator, + OrderedSet, + inner_error_default_handler, + is_inplace_api, + is_paddle_api, + log, + log_do, + map_if, + show_trackers, + tmp_name_guard, +) +from .guard import Guard, StringifyExpression, make_guard +from .mutable_data import MutationDel, MutationNew, MutationSet +from .pycode_generator import PyCodeGen +from .side_effects import ( + DictSideEffectRestorer, + GlobalDelSideEffectRestorer, + GlobalSetSideEffectRestorer, + ListSideEffectRestorer, + ObjDelSideEffectRestorer, + ObjSetSideEffectRestorer, + SideEffectRestorer, + SideEffects, +) +from .tracker import BuiltinTracker, DummyTracker +from .variables import ( + DictVariable, + GlobalVariable, + ListVariable, + NullVariable, + PaddleLayerVariable, + TensorVariable, + VariableBase, + VariableFactory, + find_traceable_vars, + map_variables, +) + + +def convert_to_meta(inputs: Any): + """ + Convert the input variables to meta if it is TensorVariable. + """ + + def func(x): + if isinstance(x, TensorVariable): + return x.meta + if isinstance(x, VariableBase): + return x.get_py_value() + return x + + return map_variables(func, inputs) + + +def convert_to_symbol(inputs: Any): + """ + Convert the input variables to symbol if it can be symbolic. + """ + + def func(x): + if isinstance(x, (TensorVariable, PaddleLayerVariable)): + return x.get_symbol() + if isinstance(x, VariableBase): + return x.get_py_value() + return x + + return map_variables(func, inputs) + + +class FunctionGraph: + """ + A Graph representation corresponding to each FunctionFrame + The input binding diagram containing the current call represents three parts of output settings, + This Graph can be compiled as a f_locals dependency function which produce the same outputs. + """ + + OUT_VAR_PREFIX = "___SIR_out_" + Memo = namedtuple( + "function_graph_memo", + [ + 'inner_out', + 'input_variables', + "stmt_ir", + "global_guards", + "side_effects_state", + "print_variables", + "inplace_tensors", + ], + ) + + def __init__(self, frame, **kwargs): + self.sir_ctx = SymbolicTraceContext() + self.inner_out = set() + self.input_variables = [] # Store variables required within a function + self.pycode_gen = PyCodeGen(frame, disable_eval_frame=True) + self.side_effects = SideEffects() + self._global_guarded_variables: OrderedSet[VariableBase] = OrderedSet() + self._print_variables = [] + self._inplace_tensors = OrderedSet() + self.build_strategy = kwargs.get('build_strategy', None) + self._kwargs = kwargs + + @cached_property + def _builtins(self): + builtins_ = {} + # prepare builtins + for name, value in builtins.__dict__.items(): + builtins_[name] = VariableFactory.from_value( + value, self, BuiltinTracker(name), debug_name=name + ) + return builtins_ + + def add_print_variables(self, variable): + """ + Used to support psdb_print + """ + self._print_variables.append(variable) + + def add_inplace_tensors(self, variable): + """ + Used to support psdb_print + """ + self._inplace_tensors.add(variable) + + def need_add_input(self, var): + """ + Determine if it is the input of graph. + + Args: + var: The input variable. + + """ + if var.id in self.inner_out: + return False + for v in self.input_variables: + if v.id == var.id: + return False + return True + + def save_memo(self) -> FunctionGraph.Memo: + """ + Save the state of the current FunctionGraph, for future state recovery, it is used for state recovery during inline call error reporting + + NOTE: + Why don't use __deepcopy__, because memo is not a deepcopy, i.e inner_out is only a shallow copy, SIR is a deepcopy. + """ + saved_stmt_ir = deepcopy(self.sir_ctx.TOS) + return FunctionGraph.Memo( + inner_out=set(self.inner_out), + input_variables=list(self.input_variables), + stmt_ir=saved_stmt_ir, + global_guards=OrderedSet(self._global_guarded_variables), + side_effects_state=self.side_effects.get_state(), + print_variables=list(self._print_variables), + inplace_tensors=OrderedSet(self._inplace_tensors), + ) + + def restore_memo(self, memo: FunctionGraph.Memo): + """ + Restore the state of graph to memo. + + Args: + memo: Previously recorded memo + + """ + self.inner_out = memo.inner_out + self.input_variables = memo.input_variables + self.sir_ctx.replace_TOS(memo.stmt_ir) + self._global_guarded_variables = memo.global_guards + self.side_effects.restore_state(memo.side_effects_state) + self._print_variables = memo.print_variables + self._inplace_tensors = memo.inplace_tensors + + def collect_input_variables(self, inputs: list[VariableBase]): + """ + Variables required within the method + + Args: + inputs: Required VariableBase + """ + + def collect(inp): + if isinstance(inp, VariableBase) and self.need_add_input(inp): + self.input_variables.append(inp) + + map_variables( + collect, + inputs, + ) + + @property + @event_register("guard_fn") + def guard_fn(self) -> Guard: + with tmp_name_guard(): + guards = [] + with EventGuard( + "guard_fn: find vars and make stringify guard", event_level=1 + ): + for variable in find_traceable_vars( + self.input_variables + list(self._global_guarded_variables) + ): + guards.extend(variable.make_stringify_guard()) + + guards = OrderedSet(guards) + + for guard in guards: + assert isinstance( + guard, StringifyExpression + ), "guard must be StringifyExpression." + + return make_guard(guards) + + def start_compile_with_name_store(self, ret_vars, to_store_vars): + class VariableLoader: + def __init__(self, index_for_load, pycode_gen): + self._index_for_load = index_for_load + self._pycode_gen: PyCodeGen = pycode_gen + + def load(self, var, allow_push_null=True): + if isinstance(var, NullVariable): + if allow_push_null: + var.reconstruct(self._pycode_gen) + else: + # Avoid passing NULL as a parameter to the resume function + self._pycode_gen.gen_load_null_variable() + return + self._pycode_gen.gen_load_fast(self._index_for_load[var.id]) + + # var_id -> local_name mapping + index_for_load = {} + to_store_vars = list( + filter(lambda x: not isinstance(x, NullVariable), to_store_vars) + ) + self.start_compile(*(ret_vars + to_store_vars)) + name_gen = NameGenerator("__start_compile_saved_") + for var in to_store_vars: + index_for_load[var.id] = name_gen.next() + + def _log_fn(): + print( + f"[StartCompile] saved var: {index_for_load[var.id]} = ", + var, + ) + + log_do(4, _log_fn) + + for var in to_store_vars[::-1]: + self.pycode_gen.gen_store_fast(index_for_load[var.id]) + return VariableLoader(index_for_load, self.pycode_gen) + + @event_register("start_compile", event_level=2) + def start_compile(self, *ret_vars: VariableBase): + """ + Generate bytecode based on the information collected by the simulation execution. + + This consists of the following steps: + - Compile the FunctionGraph into a dy2st StaticFunction and load it in the generated bytecode + - Load the group network input + - Calling the generated dy2st StaticFunction + - Restore the side effects + - Restore the output + - Return the top of the stack + """ + from ..breakpoint import BreakpointManager + + BreakpointManager().on_event("start_compile") + + ret_items = [ + ret_item + for ret_var in ret_vars + for ret_item in ret_var.flatten_items() + ] + + tensor_items = self._find_tensor_outputs(ret_items) + compiled_fn, statment_ir = self.sir_ctx.compile_fn( + [Symbol(tensor_var.var_name) for tensor_var in tensor_items], + **self._kwargs, + ) + input_names = statment_ir.inputs + compiled_fn_name = f"__compiled_fn_{statment_ir.name}" + # prepare function and inputs + self.pycode_gen.gen_load_object(compiled_fn, compiled_fn_name) + for name in input_names: + found = False + for variable in self.input_variables: + if ( + isinstance(variable, TensorVariable) + and variable.get_symbol().name == name + ): + variable.tracker.gen_instructions(self.pycode_gen) + found = True + break + assert found, f"can't find input {name} in SIR." + # Pack all args into a tuple, because we don't support *args now. + self.pycode_gen.gen_build_tuple(count=len(input_names)) + # call the compiled_fn + self.pycode_gen.gen_call_function(argc=1) + + # Store outputs to f_locals + self.pycode_gen.gen_unpack_sequence(count=len(tensor_items)) + for tensor_var in tensor_items: + self.pycode_gen.gen_store_fast(tensor_var.out_var_name) + # restore the outputs. + for ret_var in ret_vars: + ret_var.reconstruct(self.pycode_gen) + + # deal side effect + self.restore_inplace_tensor(self._inplace_tensors) + self.restore_print_stmts(self._print_variables) + self.restore_side_effects(self.side_effects.proxy_variables) + self.pycode_gen.gen_enable_eval_frame() + + tracker_output_path = show_trackers() + if tracker_output_path: + from .tracker_viewer import view_tracker + + view_tracker(list(ret_vars), tracker_output_path, format="png") + + def call_paddle_api( + self, + func: Callable[..., Any], + *args: VariableBase, + **kwargs: VariableBase, + ): + """ + Record Paddle Networking API to SIR + + Args: + func: paddle api + """ + assert is_paddle_api(func) + # not fallback api, start symbolic trace. + # TODO(xiokgun): may have python buildin object inside metas. + # TODO(xiokgun): 4 kinds of python arguments. support it !! + log(3, f"call paddle.api : {func.__name__}", "\n") + + def message_handler(*args, **kwargs): + return f"Call paddle_api error: {func.__name__}, may be not a operator api ?" + + return inner_error_default_handler(self.symbolic_call, message_handler)( + InferMetaCache(), self.sir_ctx.call_API, func, *args, **kwargs + ) + + def call_tensor_method( + self, method_name: str, *args: VariableBase, **kwargs + ): + """ + call tensor method, start symbolic trace. + + Args: + method_name: tensor method name + """ + + def message_handler(*args, **kwargs): + return f"Call tensor_method error: Tensor.{method_name}, may be not a valid operator api ?" + + return inner_error_default_handler(self.symbolic_call, message_handler)( + InferMetaCache(), + self.sir_ctx.call_METHOD, + method_name, + *args, + **kwargs, + ) + + @staticmethod + def get_opcode_executor_stack(): + # NOTE: only for debug. + # dependent on OpcodeExecutor. + from .opcode_executor import OpcodeExecutorBase + + if len(OpcodeExecutorBase.call_stack) == 0: + # In test case, we can meet this senario. + return [] + current_executor = OpcodeExecutorBase.call_stack[-1] + current_line = current_executor._current_line + filename = current_executor._code.co_filename + source_lines, start_line = inspect.getsourcelines( + current_executor._code + ) + # TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph, + # We need to find a way to fix this. + line_idx = min(current_line - start_line, len(source_lines) - 1) + code_line = source_lines[line_idx] + stack = [] + stack.append( + ' File "{}", line {}, in {}'.format( + filename, + current_line, + current_executor._code.co_name, + ) + ) + stack.append(f' {code_line}') + return stack + + def call_layer( + self, + layer: PaddleLayerVariable, + *args: VariableBase, + **kwargs: VariableBase, + ): + """ + call paddle layer, start symbolic trace. + + Args: + layer: paddle layer + """ + + def infer_meta_fn(layer, *metas, **kwmetas): + metas = LayerInferMetaCache()(layer.value, *metas, **kwmetas) + return metas + + def compute_fn(layer, inputs, outputs, stacks): + self.sir_ctx.call_LAYER( + layer.value, + inputs=inputs, + outputs=outputs, + stacks=stacks, + ) + + def message_handler(*args, **kwargs): + return f"Call paddle layer error: {layer}, may be not a valid paddle layer ?" + + return inner_error_default_handler(self.symbolic_call, message_handler)( + infer_meta_fn, compute_fn, layer, *args, **kwargs + ) + + def symbolic_call(self, infer_meta_fn, compute_fn, func, *args, **kwargs): + """ + Using infer_meta_fn and compute_fn convert func to symbolic function. + + Args: + infer_meta_fn: function for infer meta, (func, metas, kwmetas) -> output_metas + compute_fn : function for sir compile, (func, input_symbols, outputs_symbols) -> None + func : symbolic function + """ + self.collect_input_variables(list(args)) + self.collect_input_variables(list(kwargs.values())) + metas = convert_to_meta(args) + kwmetas = convert_to_meta(kwargs) + + out_metas = infer_meta_fn(func, *metas, **kwmetas) + inputs_symbols = ( + convert_to_symbol(args), + convert_to_symbol(kwargs), + ) + log(3, f" inputs : {inputs_symbols}", "\n") + + outputs = map_if( + out_metas, + pred=lambda x: isinstance(x, MetaInfo), + true_fn=lambda x: TensorVariable( + x, + self, + tracker=DummyTracker(list(args) + list(kwargs.values())), + ), + false_fn=lambda x: x, + ) + stmt_stacks = [] + log_do( + 3, + lambda: stmt_stacks.extend( + FunctionGraph.get_opcode_executor_stack() + ), + ) + if outputs is not None: + if is_inplace_api(func): + # if we want to use a non-inplace api (static api) to replace an inplace behavior (in simulation) + # just set it back in SIR, and return outputs to replace tensor meta (it might changes?) + # in this case, the output will not exactly be used + compute_fn( + func, + inputs_symbols, + convert_to_symbol(args[0]), + stmt_stacks, + ) + else: + compute_fn( + func, + inputs_symbols, + convert_to_symbol(outputs), + stmt_stacks, + ) # symbolic only contain symbols. + self._put_inner(outputs) + return VariableFactory.from_value( + outputs, self, DummyTracker(list(args) + list(kwargs.values())) + ) + else: + return None + + def _put_inner(self, vars: VariableBase): + """ + put inner variable to inner_out + """ + map_if( + vars, + pred=lambda x: isinstance(x, VariableBase), + true_fn=lambda x: self.inner_out.add(x.id), + false_fn=lambda x: None, + ) + + def add_global_guarded_variable(self, variable: VariableBase): + """ + Add variable to global guarded variable + """ + self._global_guarded_variables.add(variable) + + def remove_global_guarded_variable(self, variable: VariableBase): + """ + Remove variable to global guarded variable + """ + if variable in self._global_guarded_variables: + self._global_guarded_variables.remove(variable) + + def _find_tensor_outputs( + self, outputs: list[VariableBase] + ) -> OrderedSet[TensorVariable]: + """ + Return all TensorVariable. find TensorVariables participating in networking from the output Variables + + Args: + outputs: output variables + """ + output_tensors: OrderedSet[TensorVariable] = OrderedSet() + # Find Tensor Variables from outputs. + for output in outputs: + if isinstance(output.tracker, DummyTracker): + if isinstance(output, TensorVariable): + output_tensors.add(output) + else: + # Guard output that can not be traced. + self.add_global_guarded_variable(output) + # Find Tensor Variables from side effects Variables. + for side_effect_var in self.side_effects.proxy_variables: + if isinstance(side_effect_var, (ListVariable, DictVariable)): + for var in side_effect_var.flatten_items(): + if ( + isinstance(var.tracker, DummyTracker) + and isinstance(var, TensorVariable) + and side_effect_var.tracker.is_traceable() + ): + output_tensors.add(var) + else: + if isinstance(side_effect_var, GlobalVariable): + proxy_records = side_effect_var.proxy.records + elif side_effect_var.tracker.is_traceable(): + # for attr side effect + proxy_records = side_effect_var.attr_proxy.records + else: + continue + for record in proxy_records: + if isinstance(record, (MutationSet, MutationNew)): + for var in record.value.flatten_items(): + if isinstance( + var.tracker, DummyTracker + ) and isinstance(var, TensorVariable): + output_tensors.add(var) + # Find Tensor in print_stmts + for print_stmt in self._print_variables: + for var in print_stmt.flatten_items(): + if isinstance(var.tracker, DummyTracker) and isinstance( + var, TensorVariable + ): + output_tensors.add(var) + + # add inplace tensors into output tensors. + for inplace_tensor in self._inplace_tensors: + output_tensors.add(inplace_tensor) + + return output_tensors + + def restore_print_stmts(self, variables: list[VariableBase]): + for var in variables: + var.reconstruct( + self.pycode_gen, + use_tracker=False, + add_to_global_guarded_vars=False, + ) + + def restore_inplace_tensor(self, variables: list[VariableBase]): + for var in variables: + if not var.tracker.is_traceable(): + continue + var.reconstruct( + self.pycode_gen, + use_tracker=True, + add_to_global_guarded_vars=False, + ) + self.pycode_gen.gen_load_method( + "_inplace_assign" + ) # NOTE: paddle related logic. + var.reconstruct( + self.pycode_gen, + use_tracker=False, + add_to_global_guarded_vars=True, + ) + self.pycode_gen.gen_call_method(1) + self.pycode_gen.gen_pop_top() + + def restore_side_effects(self, variables: list[VariableBase]): + """ + Generate side effect recovery code for variables with side effects + + Args: + variables: Variables that may have side effects. + """ + restorers: list[SideEffectRestorer] = [] + + for var in variables: + # skip inner variables + if not var.tracker.is_traceable() and not isinstance( + var, GlobalVariable + ): + continue + if isinstance(var, DictVariable): + restorers.append(DictSideEffectRestorer(var)) + elif isinstance(var, ListVariable): + restorers.append(ListSideEffectRestorer(var)) + else: + if isinstance(var, GlobalVariable): + for record in var.proxy.records[::-1]: + if isinstance(record, (MutationSet, MutationNew)): + restorers.append( + GlobalSetSideEffectRestorer( + record.key, + record.value, + ) + ) + elif isinstance(record, MutationDel): + restorers.append( + GlobalDelSideEffectRestorer(record.key) + ) + else: + for record in var.attr_proxy.records[::-1]: + if isinstance(record, (MutationSet, MutationNew)): + restorers.append( + ObjSetSideEffectRestorer( + var, + record.key, + record.value, + ) + ) + elif isinstance(record, MutationDel): + restorers.append( + ObjDelSideEffectRestorer( + var, + record.key, + ) + ) + + for restorer in restorers: + restorer.pre_gen(self.pycode_gen) + for restorer in restorers[::-1]: + restorer.post_gen(self.pycode_gen) diff --git a/python/paddle/jit/sot/opcode_translator/executor/guard.py b/python/paddle/jit/sot/opcode_translator/executor/guard.py new file mode 100644 index 00000000000000..b839c064f407da --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/guard.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import types +import weakref +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from ...profiler import EventGuard +from ...utils import InnerError, current_tmp_name_records, log, log_do + +Guard = Callable[[types.FrameType], bool] + +if TYPE_CHECKING: + from .variables import VariableBase + + CheckGuardInputT = TypeVar("CheckGuardInputT", bound=VariableBase) + +# NOTE(SigureMo): [How to write Stringify Guard?] +# 1. we should capture free variables manually, the string cannot capture free +# variables automatically. +# 2. Be aware that the comparison logic before and after stringify may be different. +# 3. we should compute as much as possible at "compile time" and encode the +# computation in the Guard string, rather than passing it to runtime to minimize +# runtime overhead. + + +class StringifyExpression: + """ + Used to store string based expressions for generating Guard. + """ + + def __init__(self, str_expr, sub_exprs, free_vars): + expr = str_expr.format(*[arg.expr for arg in sub_exprs]) + self.expr = current_tmp_name_records().add_tmp_var(expr) + self.debug_expr = str_expr.format( + *[arg.debug_expr for arg in sub_exprs] + ) + self.free_vars = free_vars + + def __post_init__(self): + self.check_expr(self.expr) + + def check_expr(self, expr: str): + try: + pass + # ast.parse(expr) # TODO(xiongkun): too slow + except SyntaxError as e: + raise InnerError(f"Invalid expression: {expr}") from e + + def __hash__(self): + if self.free_vars: + return hash((self.debug_expr, id(self))) + else: + return hash(self.debug_expr) + + +def union_free_vars(*free_vars: dict[str, Any]): + return {k: v for d in free_vars for k, v in d.items()} + + +def make_guard(stringify_guards: list[StringifyExpression]) -> Guard: + """ + Make a guard from a list of StringifyExpression. + + For more design ideas, refer to the `Stringify guard <https://github.com/PaddlePaddle/PaddleSOT/blob/develop/docs/design/stringify-guard.md>`_ for details. + + Args: + stringify_guards: a list of StringifyExpression. + """ + with EventGuard("make_guard"): + num_guards = len(stringify_guards) + if not num_guards: + guard = lambda frame: True + guard.expr = "lambda frame: True" + return guard + + def analyse_expresions(stringify_exprs, tmp_names): + func_string = "def built_guard_fn(frame):\n" + lambda_string = "lambda frame: " + free_vars = {} + + for k, v in tmp_names.items(): + func_string += f" {v} = {k}\n" + + func_result = "" + for str_expr in stringify_exprs: + func_result += str_expr.expr + " and " + lambda_string += str_expr.debug_expr + " and " + free_vars = union_free_vars(free_vars, str_expr.free_vars) + + func_string += f" return {func_result[:-5]}" + + return func_string, free_vars, lambda_string[:-5] + + ( + func_string, + free_vars, + lambda_string, + ) = analyse_expresions( + stringify_guards, current_tmp_name_records().tmp_names_record + ) + + exec( + func_string, + free_vars, + ) + + guard = free_vars['built_guard_fn'] + log(3, f"[Guard]: {lambda_string}\n") + guard.lambda_expr = lambda_string + guard.expr = func_string + assert callable(guard), "guard must be callable." + + return guard + + +def support_weak_ref(obj): + if isinstance(obj, types.FunctionType): + return True + return False + + +def check_guard( + fn: Callable[[CheckGuardInputT], list[StringifyExpression]] +) -> Callable[[CheckGuardInputT], list[StringifyExpression]]: + def wrapper(self: CheckGuardInputT) -> list[StringifyExpression]: + assert ( + self.tracker.is_traceable() + ), "Cannot make guard from a non-tracable guard variable." + + def guard_log(): + frame_value_tracer = self.tracker.trace_value_from_frame() + print( + f"[Guard]: guard_fn for {self}, tracker={self.tracker.__class__.__name__}, value={frame_value_tracer.expr}" + ) + + log_do(4, guard_log) + return fn(self) + + return wrapper + + +@check_guard +def object_equal_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + obj_free_var_name = f"__{self.id}" + weak_ref_obj = self.get_py_value() + if support_weak_ref(weak_ref_obj): + weak_ref_obj = weakref.ref(self.get_py_value()) + return [ + StringifyExpression( + f"{obj_free_var_name}() is not None and {{}} == {obj_free_var_name}()", + [frame_value_tracer], + union_free_vars( + frame_value_tracer.free_vars, + {obj_free_var_name: weak_ref_obj}, + ), + ) + ] + return [ + StringifyExpression( + f"{{}} == {obj_free_var_name}", + [frame_value_tracer], + union_free_vars( + frame_value_tracer.free_vars, + {obj_free_var_name: self.get_py_value()}, + ), + ) + ] diff --git a/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py b/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py new file mode 100644 index 00000000000000..1dd795439d4597 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/instr_flag.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flags for instructions + + +class FORMAT_VALUE_FLAG: + FVC_MASK = 0x3 + FVC_NONE = 0x0 + FVC_STR = 0x1 + FVC_REPR = 0x2 + FVC_ASCII = 0x3 + FVS_MASK = 0x4 + FVS_HAVE_SPEC = 0x4 + + +class MAKE_FUNCTION_FLAG: + MF_HAS_CLOSURE = 0x08 + MF_HAS_ANNOTATION = 0x04 + MF_HAS_KWDEFAULTS = 0x02 + MF_HAS_DEFAULTS = 0x01 + + +class CALL_FUNCTION_EX_FLAG: + CFE_HAS_KWARGS = 0x01 diff --git a/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py b/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py new file mode 100644 index 00000000000000..d6bda43d42ef4e --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/mutable_data.py @@ -0,0 +1,289 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar + +if TYPE_CHECKING: + from typing_extensions import Concatenate, ParamSpec, TypeAlias + + P = ParamSpec("P") + R = TypeVar("R") + + MutableDataT = TypeVar("MutableDataT", bound="MutableData") + DataGetter: TypeAlias = Callable[[MutableDataT, Any], Any] + +InnerMutableDataT = TypeVar( + "InnerMutableDataT", bound="dict[str, Any] | list[Any]" +) + + +class Mutation: + ABBR: str + + +class MutationSet(Mutation): + """ + Setting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "S" + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return f"MutationSet({self.key}, {self.value})" + + +class MutationDel(Mutation): + """ + Deleting a value. + This mutation is used for MutableDictLikeData and MutableListLikeData. + """ + + ABBR = "D" + + def __init__(self, key): + self.key = key + + def __repr__(self): + return f"MutationDel({self.key})" + + +class MutationNew(Mutation): + """ + Adding a new value. + This mutation is only used for MutableDictLikeData. + """ + + ABBR = "N" + + def __init__(self, key, value): + self.key = key + self.value = value + + def __repr__(self): + return f"MutationNew({self.key}, {self.value})" + + +class MutationInsert(Mutation): + """ + Inserting a value. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "I" + + def __init__(self, index, value): + self.index = index + self.value = value + + def __repr__(self): + return f"MutationInsert({self.index}, {self.value})" + + +class MutationPermutate(Mutation): + """ + Permutating all the values. + This mutation is only used for MutableListLikeData. + """ + + ABBR = "P" + + def __init__(self, permutation): + self.permutation = permutation + + def __repr__(self): + return f"MutationPermutate({self.permutation})" + + +def record_mutation( + mutation_fn: Callable[Concatenate[MutableDataT, P], Mutation] +) -> Callable[Concatenate[MutableDataT, P], None]: + def wrapper(self, *args: P.args, **kwargs: P.kwargs): + mutation = mutation_fn(self, *args, **kwargs) + self.records.append(mutation) + + return wrapper + + +class MutableData(Generic[InnerMutableDataT]): + """ + An intermediate data structure between data and variable, it records all the mutations. + """ + + read_cache: InnerMutableDataT + + class Empty: + def __repr__(self): + return "Empty()" + + def __init__(self, data: Any, getter: DataGetter): + self.original_data = data + self.getter = getter + self.records: list[Mutation] = [] + + def is_empty(self, value): + return isinstance(value, MutableData.Empty) + + @property + def version(self): + return len(self.records) + + @property + def has_changed(self): + return self.version != 0 + + def rollback(self, version: int): + assert version <= self.version + self.records[:] = self.records[:version] + + def get(self, key): + raise NotImplementedError() + + def set(self, key, value): + raise NotImplementedError() + + def apply(self, mutation: Mutation, write_cache: InnerMutableDataT): + raise NotImplementedError() + + def reproduce(self, version: int | None = None) -> InnerMutableDataT: + if version is None: + version = self.version + write_cache = self.read_cache.copy() + for mutation in self.records[:version]: + self.apply(mutation, write_cache) + return write_cache + + def __repr__(self) -> str: + records_abbrs = "".join([mutation.ABBR for mutation in self.records]) + return f"{self.__class__.__name__}({records_abbrs})" + + +class MutableDictLikeData(MutableData["dict[str, Any]"]): + def __init__(self, data: Any, getter: DataGetter): + super().__init__(data, getter) + self.read_cache = {} + + def clear_read_cache(self): + self.read_cache.clear() + + def get(self, key: Any): + # TODO(SigureMo): Optimize performance of this. + write_cache = self.reproduce(self.version) + if key not in write_cache: + self.read_cache[key] = self.getter(self, key) + return self.reproduce(self.version)[key] + + def get_all(self): + original_keys = list(self.original_data.keys()) + for mutation in self.records: + if isinstance(mutation, MutationNew): + original_keys.append(mutation.key) + elif isinstance(mutation, MutationDel): + original_keys.remove(mutation.key) + return {key: self.get(key) for key in original_keys} + + @record_mutation + def set(self, key: Any, value: Any) -> Mutation: + is_new = False + if self.is_empty(self.get(key)): + is_new = True + return ( + MutationSet(key, value) if not is_new else MutationNew(key, value) + ) + + @record_mutation + def delete(self, key): + return MutationDel(key) + + def apply(self, mutation: Mutation, write_cache: dict[str, Any]): + if isinstance(mutation, MutationNew): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationSet): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationDel): + write_cache[mutation.key] = MutableData.Empty() + else: + raise ValueError(f"Unknown mutation type {mutation}") + + def reproduce(self, version: int | None = None): + if version is None: + version = self.version + write_cache = self.read_cache.copy() + for mutation in self.records[:version]: + self.apply(mutation, write_cache) + return write_cache + + +class MutableListLikeData(MutableData["list[Any]"]): + def __init__(self, data: Any, getter: DataGetter): + super().__init__(data, getter) + self.read_cache = [ + self.getter(self, idx) for idx in range(len(self.original_data)) + ] + + def clear_read_cache(self): + self.read_cache[:] = [] + + @property + def length(self): + return len(self.reproduce()) + + def get(self, key): + write_cache = self.reproduce(self.version) + return write_cache[key] + + def get_all(self) -> list[Any]: + items = self.reproduce(self.version) + return items + + @record_mutation + def set(self, key: int, value: Any): + return MutationSet(self._regularize_index(key), value) + + @record_mutation + def delete(self, key: int): + return MutationDel(self._regularize_index(key)) + + @record_mutation + def insert(self, index: int, value: Any): + return MutationInsert(self._regularize_index(index), value) + + @record_mutation + def permutate(self, permutation: list[int]): + return MutationPermutate(permutation) + + def _regularize_index(self, index: int): + if index < 0: + index += self.length + return index + + def apply(self, mutation: Mutation, write_cache: list[Any]): + if isinstance(mutation, MutationSet): + write_cache[mutation.key] = mutation.value + elif isinstance(mutation, MutationDel): + write_cache[:] = ( + write_cache[: mutation.key] + write_cache[mutation.key + 1 :] + ) + elif isinstance(mutation, MutationInsert): + write_cache.insert(mutation.index, mutation.value) + elif isinstance(mutation, MutationPermutate): + write_cache[:] = [write_cache[i] for i in mutation.permutation] + else: + raise ValueError(f"Unknown mutation type {mutation}") diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py new file mode 100644 index 00000000000000..6d9ec8829497a5 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -0,0 +1,2070 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dis +import functools +import inspect +import operator +import sys +import traceback +import types +from dataclasses import dataclass +from itertools import chain +from typing import Any, Callable + +import opcode + +from ...profiler import EventGuard, event_register +from ...psdb import NO_BREAKGRAPH_CODES +from ...utils import ( + BreakGraphError, + FallbackError, + InnerError, + OrderedSet, + SotUndefinedVar, + log, + log_do, + min_graph_size, +) +from ..custom_code import CustomCode +from ..instruction_utils import ( + Instruction, + Space, + analysis_inputs, + analysis_used_names_with_space, + calc_stack_effect, + get_instructions, +) +from ..instruction_utils.opcode_info import JumpDirection, PopJumpCond +from .dispatch_functions import ( + operator_BAD, + operator_exception_match, + operator_in, + operator_is_none, + operator_is_not_none, + operator_not_in, +) +from .dispatcher import Dispatcher +from .function_graph import FunctionGraph +from .instr_flag import CALL_FUNCTION_EX_FLAG as CFE +from .instr_flag import FORMAT_VALUE_FLAG as FV +from .instr_flag import MAKE_FUNCTION_FLAG as MF +from .pycode_generator import PyCodeGen +from .tracker import ( + CellTracker, + ConstTracker, + DanglingTracker, + DummyTracker, + LocalTracker, +) +from .variable_stack import VariableStack +from .variables import ( + BuiltinVariable, + CellVariable, + ConstantVariable, + ContainerVariable, + DictVariable, + GlobalVariable, + ListVariable, + MethodVariable, + NullVariable, + SequenceIterVariable, + SliceVariable, + TensorVariable, + TupleVariable, + UserDefinedFunctionVariable, + VariableBase, + VariableFactory, +) + +SUPPORT_COMPARE_OP = { + ">": operator.gt, + "<": operator.lt, + ">=": operator.ge, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + "is not": operator.is_not, + "is": operator.is_, + "in": operator_in, + "not in": operator_not_in, + "exception match": operator_exception_match, + "BAD": operator_BAD, +} + + +@dataclass +class Stop: + state: str + + +def tos_op_wrapper(fn: Callable): + """ + A decorator function that wraps an opcode operation and applies certain functionality to it. + + Args: + fn: The opcode operation to be wrapped. + + Returns: + The wrapped opcode operation. + """ + nargs = len(inspect.signature(fn).parameters) + + @call_break_graph_decorator(push_n=1) + def inner(self: OpcodeExecutorBase, instr: Instruction): + args = self.stack.pop_n(nargs) + res = BuiltinVariable(fn, graph=self._graph, tracker=DanglingTracker())( + *args + ) + self.stack.push(res) + + return inner + + +def tos_inplace_op_wrapper(fn: Callable): + """ + A decorator function that wraps an inplace opcode operation and applies certain functionality to it. + + Args: + fn: The inplace opcode operation to be wrapped. + + Returns: + The wrapped inplace opcode operation. + + """ + + @call_break_graph_decorator(push_n=1) + def inner(self: OpcodeExecutorBase, instr: Instruction): + """ + Inner function that represents the wrapped inplace opcode operation. + + Args: + self: The instance of the OpcodeExecutorBase class. + instr: The instruction to be executed. + + """ + args = self.stack.pop_n(2) + res = BuiltinVariable(fn, graph=self._graph, tracker=DanglingTracker())( + *args + ) + res.debug_name = args[0].debug_name + self.stack.push(res) + + return inner + + +def pop_jump_if_op_wrapper(fns: list[Callable[[Any], Any]]): + """ + A decorator function that wraps a POP_JUMP_*_IF_* opcode operation and applies certain functionality to it. + + Args: + fn: The condition function. + + Returns: + The wrapped POP_JUMP_*_IF_* opcode operation. + + """ + + @jump_break_graph_decorator + def inner(self: OpcodeExecutorBase, instr: Instruction): + """ + Inner function that represents the wrapped POP_JUMP_IF opcode operation. + + Args: + self: The instance of the OpcodeExecutorBase class. + instr: The instruction to be executed. + + """ + pred_obj = self.stack.pop() + + try: + self._graph.add_global_guarded_variable(pred_obj) + res = pred_obj + for fn in fns: + res = BuiltinVariable( + fn, graph=self._graph, tracker=DanglingTracker() + )(res) + + assert isinstance(res, ConstantVariable) + is_jump = res.get_py_value() + assert isinstance(is_jump, bool) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + except BreakGraphError: + raise FallbackError( + f"Currently don't support predicate {pred_obj.__class__.__name__}" + ) + + return inner + + +def jump_break_graph_decorator(normal_jump: Callable): + """ + A decorator function that breaks off the graph when a JUMP-related instruction is encountered. + + Args: + normal_jump: The normal jump operation. + + Returns: + The wrapped jump operation. + + """ + + def inner(self: OpcodeExecutor, instr: Instruction): + result = self.stack.top + if isinstance(result, TensorVariable): + self.stack.pop() + # fallback when in OpcodeExecutor + # raise error in OpcodeInlineExecutor + log(3, "[BreakGraph] jump break graph, because if tensor\n") + self._break_graph_in_jump(result, instr) + return Stop(state="BreakGraph") + else: + return normal_jump(self, instr) + + return inner + + +def call_break_graph_decorator(push_n: int | Callable[[int | None], int]): + """ + A decorator function that breaks off the graph when a function CALL instruction is encountered. + + Args: + push_n: The number of arguments to be pushed onto the stack. + + Returns: + The decorated function. + + """ + + def decorate(call_fn: Callable): + @functools.wraps(call_fn) + def wrapper(self: OpcodeExecutor, instr: Instruction): + origin_stack = self.stack.copy() + try: + return call_fn(self, instr) + except BreakGraphError as e: + if self._code in NO_BREAKGRAPH_CODES: + raise InnerError( + f"{self._code.co_name} should not break graph, but got '{e}'" + ) + if isinstance(self, OpcodeExecutor): + log(3, f"[BreakGraph] call function Break graph: {e}\n") + self._break_graph_in_call(origin_stack, instr, push_n) + return Stop(state="BreakGraph") + else: + raise e + + return wrapper + + return decorate + + +def fallback_when_occur_error(fn: Callable): + """ + A decorator function that provides fallback behavior when an error occurs during graph processing. + + Args: + fn: The function to be wrapped. + + Returns: + The wrapped function. + + """ + + def inner(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + raise FallbackError( + f'[Fallback] An exception occurred when processing break graph, fallback to dygraph, error message is: \n{type(e)} : {e}\n' + ) + + return inner + + +class OpcodeExecutorBase: + """ + Base class for executing opcode instructions. + + The OpcodeExecutorBase class provides methods and functionality to execute opcode instructions. + + If you want to learn more about Python instructions, see https://docs.python.org/3/library/dis.html for details. + + Args: + code: The bytecode of the function to be executed. + graph: The function graph. + + Attributes: + call_stack (list[OpcodeExecutorBase]): A list to keep track of the call stack. + _stack (list[VariableBase]): The stack used for storing variables during execution. + _co_consts: List to store constants. + _locals (dict): Dictionary to store local variables. + _globals (dict): Dictionary to store global variables. + _builtins (dict): Dictionary to store built-in variables. + _lasti (int): Index of the last executed instruction. + _code (types.CodeType): The code object to be executed. + _instructions: Iterator of opcode instructions. + _graph (FunctionGraph): The function graph representing the code. + _current_line: The current line number of the execution. + new_code: Placeholder for new code (to be generated by PyCodeGen). + guard_fn: Placeholder for guard function. + _name (str): Name of the executor. + + """ + + call_stack: list[OpcodeExecutorBase] = [] + + @staticmethod + def validate_value(value): + assert isinstance( + value, VariableBase + ), f"value: {value}, type shoule be VariableBase(or derived), but get {type(value)}" + assert not isinstance(value.tracker, DanglingTracker) or isinstance( + value, (NullVariable, CellVariable) + ), f"dangling variable {value} should not be pushed into stack." + + def __init__(self, code: types.CodeType, graph: FunctionGraph): + OpcodeExecutorBase.call_stack.append(self) + # fake env for run, new env should be gened by PyCodeGen + self.stack = VariableStack(validate_value_func=self.validate_value) + self._co_consts = [] + self._locals = {} + self._globals: GlobalVariable = None # type: ignore + self._builtins = {} + self._cells = {} # position to put cells + self._lasti = 0 # idx of instruction list + self._code = code + self._current_line: int = -1 + self._instructions = get_instructions(self._code) + self._graph = graph + self.new_code: types.CodeType | None = None + self.guard_fn = None + self._name = "Executor" + self._call_shape: tuple[ + str, ... + ] | None = None # store kwnames for Python 3.11+ + self._prepare_virtual_env() + + self.stop_state = None + + def print_sir(self): + """ + Prints the Static Instruction Representation (SIR) in the executor. + + """ + print(self._graph.sir_ctx.TOS) + + def _prepare_virtual_env(self): + """ + Prepares the virtual environment for the executor. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError("Please implement virtual_env.") + + def _break_graph_in_jump(self, result, instr: Instruction): + """ + Breaks the graph in JUMP instructions. + + Args: + result: The execution result. + instr: The jump instruction. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError() + + def transform(self): + """ + Abstract method need to be implemented to symbolic translate each instruction. + + Raises: + NotImplementedError: If the method is not implemented. + + """ + raise NotImplementedError() + + def get_var(self, name: str): + """ + Gets the variable with the given name. + + Args: + name: The name of the variable. + + Returns: + The variable. + + Raises: + InnerError: If the variable cannot be found. + + """ + if name in self._locals.keys(): + return self._locals[name] + elif name in self._cells.keys(): # in closure + return self._cells[name].cell_content() + elif name in self._globals.keys(): + return self._globals.get(name) + elif name in self._builtins.keys(): + return self._builtins[name] + else: + raise InnerError(f'Can not get var: {name}') + + def has_var(self, name: str, space: str = "any"): + if space == "any": + return name in set( + chain( + self._locals.keys(), + self._cells.keys(), + self._globals.keys(), + self._builtins.keys(), + ) + ) + elif space == Space.locals: + return name in self._locals + elif space == Space.cells: + return name in self._cells + elif space == Space.globals: + return name in set( + chain( + self._globals.keys(), + self._builtins.keys(), + ) + ) + return False + + def pop_call_stack_until_self(self): + """ + Pops the call stack until the current executor. + + """ + assert ( + self in OpcodeExecutorBase.call_stack + ), f"{self} not in call stack" + while OpcodeExecutorBase.call_stack.pop() is not self: + pass + + @staticmethod + def error_message_summary(original_error: Exception) -> str: + """ + Creates a summary of the error message during execution. + + Args: + original_error: The original error. + + Returns: + The summary error message. + + """ + indent = 2 * " " + message_lines = ["In simulate execution:", ""] + for current_simulator in OpcodeExecutorBase.call_stack: + code = current_simulator._code + current_line = current_simulator._current_line + lines, start = inspect.getsourcelines(code) + real_name = code.co_name + message_lines.append( + f"{indent} File \"{code.co_filename}\", line {current_line}, in {real_name}" + ) + if current_line != -1: + message_lines.append( + f"{indent} {lines[current_line-start].rstrip()}" + ) + error_message = traceback.format_exception_only( + type(original_error), original_error + ) + for line in error_message: + line = line.rstrip() + message_lines.append(f"{indent} {line}") + return "\n".join(message_lines) + + def run(self): + """ + Executes the opcode. + + """ + log(3, f"start execute opcode: {self._code}\n") + self._lasti = 0 + while True: + if self._lasti >= len(self._instructions): + raise InnerError("lasti out of range, InnerError.") + cur_instr = self._instructions[self._lasti] + self._lasti += 1 + is_stop = self.step(cur_instr) + if is_stop: + self.stop_state = is_stop.state + self.pop_call_stack_until_self() + break + + def step(self, instr: Instruction): + """ + Executes a single step of the opcode. + + Args: + instr: The instruction to be executed. + + Returns: + True if execution should stop, False otherwise. + + Raises: + FallbackError: If the opcode is not supported. + + """ + if instr.starts_line is not None: + self._current_line = instr.starts_line + if not hasattr(self, instr.opname): + raise FallbackError(f"opcode: {instr.opname} is not supported.") + log_message = f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self.stack}\n" + log(3, log_message) + code_file = self._code.co_filename + code_line = self._current_line + code_name = self._code.co_name + code_offset = instr.offset + from ..breakpoint import BreakpointManager + + if BreakpointManager().hit( + code_file, code_line, code_name, code_offset + ): + BreakpointManager().locate(self) + print(log_message) + breakpoint() # breakpoint for debug + + with EventGuard(f"{instr.opname}", event_level=1): + return getattr(self, instr.opname)(instr) # run single step. + + def indexof(self, instr: Instruction): + """ + Gets the index of the instruction. + + Args: + instr: The instruction. + + Returns: + The index of the instruction. + + """ + return self._instructions.index(instr) + + def jump_to(self, instr: Instruction): + """ + Jumps to the given instruction. + + Args: + instr: The instruction to jump to. + + """ + self._lasti = self.indexof(instr) + + def COPY(self, instr: Instruction): + assert isinstance(instr.arg, int) + self.stack.push(self.stack.peek[instr.arg]) + + def DUP_TOP(self, instr: Instruction): + self.stack.push(self.stack.top) + + def DUP_TOP_TWO(self, instr: Instruction): + for ref in self.stack.peek[:2]: + self.stack.push(ref) + + def ROT_N(self, instr: Instruction): + assert instr.argval is not None + self._rot_top_n(instr.argval) + + def _rot_top_n(self, n: int): + # a1 a2 a3 ... an <- TOS + # the stack changes to + # an a1 a2 a3 an-1 <- TOS + assert ( + len(self.stack) >= n + ), f"There are not enough elements on the stack. {n} is needed." + top = self.stack.pop() + self.stack.insert(n - 1, top) + + def POP_TOP(self, instr: Instruction): + self.stack.pop() + + def PUSH_NULL(self, instr: Instruction): + self.stack.push(NullVariable()) + + def ROT_TWO(self, instr: Instruction): + self._rot_top_n(2) + + def ROT_THREE(self, instr: Instruction): + self._rot_top_n(3) + + def ROT_FOUR(self, instr: Instruction): + self._rot_top_n(4) + + def RESUME(self, instr: Instruction): + # RESUME is a no-op, it just for internal tracing, debugging and optimization checks. + pass + + def SWAP(self, instr: Instruction): + assert isinstance(instr.arg, int) + self.stack.top, self.stack.peek[instr.arg] = ( + self.stack.peek[instr.arg], + self.stack.top, + ) + + # unary operators + UNARY_POSITIVE = tos_op_wrapper(operator.pos) + UNARY_NEGATIVE = tos_op_wrapper(operator.neg) + UNARY_NOT = tos_op_wrapper(operator.not_) + UNARY_INVERT = tos_op_wrapper(operator.invert) + + # binary operators + BINARY_POWER = tos_op_wrapper(operator.pow) + BINARY_MULTIPLY = tos_op_wrapper(operator.mul) + BINARY_MATRIX_MULTIPLY = tos_op_wrapper(operator.matmul) + BINARY_FLOOR_DIVIDE = tos_op_wrapper(operator.floordiv) + BINARY_TRUE_DIVIDE = tos_op_wrapper(operator.truediv) + BINARY_MODULO = tos_op_wrapper(operator.mod) + BINARY_ADD = tos_op_wrapper(operator.add) + BINARY_SUBTRACT = tos_op_wrapper(operator.sub) + BINARY_LSHIFT = tos_op_wrapper(operator.lshift) + BINARY_RSHIFT = tos_op_wrapper(operator.rshift) + BINARY_AND = tos_op_wrapper(operator.and_) + BINARY_OR = tos_op_wrapper(operator.or_) + BINARY_XOR = tos_op_wrapper(operator.xor) + + def BINARY_OP(self, instr: Instruction): + opname, _ = opcode._nb_ops[instr.arg] + opname = ( + opname.replace("NB_", "BINARY_") + .replace("BINARY_INPLACE", "INPLACE") + .replace("REMAINDER", "MODULO") + ) + return getattr(self, opname)(instr) + + @call_break_graph_decorator(push_n=1) + def BINARY_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + assert isinstance(key, VariableBase) + # TODO(xiongkun): getitem / getattr support key and attr as variable. + if isinstance(key, TensorVariable) and isinstance( + container, TensorVariable + ): + # NOTE(xiongkun): tensor[tensor] should support. + output = self._graph.call_tensor_method( + "__getitem__", container, key + ) + self.stack.push(output) + return + + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in BINARY_SUBSCR, {container}[{key}]" + ) + + result = BuiltinVariable( + operator.getitem, self._graph, DanglingTracker() + )(container, key) + self.stack.push(result) + + # inplace operators + # paddle variable do not have inplace operators. For example when call `y **= x`, will call var.__pow__ + INPLACE_POWER = tos_inplace_op_wrapper(operator.ipow) + INPLACE_MULTIPLY = tos_inplace_op_wrapper(operator.imul) + INPLACE_MATRIX_MULTIPLY = tos_inplace_op_wrapper(operator.imatmul) + INPLACE_FLOOR_DIVIDE = tos_inplace_op_wrapper(operator.ifloordiv) + INPLACE_TRUE_DIVIDE = tos_inplace_op_wrapper(operator.itruediv) + INPLACE_MODULO = tos_inplace_op_wrapper(operator.imod) + INPLACE_ADD = tos_inplace_op_wrapper(operator.iadd) + INPLACE_SUBTRACT = tos_inplace_op_wrapper(operator.isub) + INPLACE_LSHIFT = tos_inplace_op_wrapper(operator.ilshift) + INPLACE_RSHIFT = tos_inplace_op_wrapper(operator.irshift) + INPLACE_AND = tos_inplace_op_wrapper(operator.iand) + INPLACE_OR = tos_inplace_op_wrapper(operator.ior) + INPLACE_XOR = tos_inplace_op_wrapper(operator.ixor) + + def NOP(self, instr: Instruction): + pass + + @call_break_graph_decorator(push_n=1) + def LOAD_ATTR(self, instr: Instruction): + attr_name = self._code.co_names[instr.arg] + attr_name_var = ConstantVariable.wrap_literal(attr_name, self._graph) + obj = self.stack.pop() + self.stack.push( + BuiltinVariable( + getattr, graph=self._graph, tracker=DanglingTracker() + )(obj, attr_name_var) + ) + + def LOAD_CONST(self, instr: Instruction): + var = self._co_consts[instr.arg] + self.stack.push(var) + + def MAKE_CELL(self, instr: Instruction): + self._locals[instr.argval] = self._cells[instr.argval] + + def LOAD_CLOSURE(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.LOAD_FAST(instr) + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self.stack.push(self._cells[name]) + + def LOAD_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.stack.push(self._locals[instr.argval].cell_content()) + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self.stack.push(self._cells[name].cell_content()) + + def COPY_FREE_VARS(self, instr: Instruction): + for i in range(instr.arg): + freevar_name = self._code.co_freevars[i] + self._locals[freevar_name] = self._cells[freevar_name] + + def LOAD_FAST(self, instr: Instruction): + var = self._locals[instr.argval] + self.stack.push(var) + + def DELETE_FAST(self, instr: Instruction): + varname = self._code.co_varnames[instr.arg] + del self._locals[varname] + + def LOAD_GLOBAL(self, instr: Instruction): + namei: int = instr.arg + push_null = False + if sys.version_info >= (3, 11): + push_null = namei & 1 + namei >>= 1 + if push_null: + self.stack.push(NullVariable()) + name = self._code.co_names[namei] + if name in self._globals.keys(): + value = self._globals.get(name) + elif name in self._builtins.keys(): + value = self._builtins[name] + else: + raise InnerError(f"{name} not in globals and builtins") + self.stack.push(value) + + def LOAD_METHOD(self, instr: Instruction): + method_name = self._code.co_names[instr.arg] + method_name_var = ConstantVariable.wrap_literal( + method_name, self._graph + ) + obj = self.stack.pop() + + method = BuiltinVariable( + getattr, graph=self._graph, tracker=DanglingTracker() + )(obj, method_name_var) + + if isinstance(method, MethodVariable): + # bound method, push the unbound method and the self + self.stack.push(method.fn) + self.stack.push(obj) + else: + # unbound method, push the dummy and the function + self.stack.push(NullVariable()) + self.stack.push(method) + + @call_break_graph_decorator(push_n=0) + def STORE_ATTR(self, instr: Instruction): + obj = self.stack.pop() + val = self.stack.pop() + key = self._code.co_names[instr.arg] + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable( + setattr, self._graph, DummyTracker([obj, key_var, val]) + )(obj, key_var, val) + + def DELETE_ATTR(self, instr: Instruction): + obj = self.stack.pop() + key = instr.argval + key_var = ConstantVariable.wrap_literal(key, self._graph) + BuiltinVariable(delattr, self._graph, DummyTracker([obj, key_var]))( + obj, key_var + ) + + def STORE_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self._cells[instr.argval].set_value(self.stack.pop()) + self._locals[instr.argval] = self._cells[instr.argval] + return + namemap = self._code.co_cellvars + self._code.co_freevars + name = namemap[instr.arg] + self._cells[name].set_value(self.stack.pop()) + + def STORE_FAST(self, instr: Instruction): + """ + TODO: side effect may happen + """ + var = self.stack.pop() + name = self._code.co_varnames[instr.arg] + var.debug_name = name + self._locals[name] = var + + def STORE_GLOBAL(self, instr: Instruction): + var = self.stack.pop() + name = self._code.co_names[instr.arg] + var.debug_name = name + self._globals.set(name, var) + + def DELETE_GLOBAL(self, instr: Instruction): + self._globals.delete(self._code.co_names[instr.arg]) + + @call_break_graph_decorator(push_n=0) + def STORE_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + value = self.stack.pop() + assert isinstance(key, VariableBase) + self._graph.add_global_guarded_variable(key) + if isinstance(key, TensorVariable): + raise BreakGraphError( + f"Key is a TensorVariable in STORE_SUBSCR, {container}[{key}] = {value}" + ) + # TODO(xiongkun): support tensor[tensor] = tensor, dy2static is not the same with dygraph. + container[key.get_py_value()] = value + value.debug_name = f"{container.debug_name}[{key.debug_name}]" + + def DELETE_SUBSCR(self, instr: Instruction): + key = self.stack.pop() + container = self.stack.pop() + assert isinstance(key, VariableBase) + self._graph.add_global_guarded_variable(key) + BuiltinVariable(operator.delitem, self._graph, DanglingTracker())( + container, key + ) + + def BUILD_LIST(self, instr: Instruction): + list_size = instr.arg + assert list_size <= len( + self.stack + ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." + val_list = self.stack.pop_n(list_size) + self.stack.push( + ListVariable( + val_list, graph=self._graph, tracker=DummyTracker(val_list) + ) + ) + + def BUILD_TUPLE(self, instr: Instruction): + tuple_size = instr.arg + assert tuple_size <= len( + self.stack + ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." + val_tuple = self.stack.pop_n(tuple_size) + self.stack.push( + TupleVariable( + tuple(val_tuple), + graph=self._graph, + tracker=DummyTracker(val_tuple), + ) + ) + + def BUILD_STRING(self, instr: Instruction): + count = instr.arg + assert count <= len( + self.stack + ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." + str_list = self.stack.pop_n(count) + new_str = '' + for s in str_list: + assert s.get_py_type() is str + new_str += s.get_py_value() + self.stack.push( + ConstantVariable(new_str, self._graph, DummyTracker(str_list)) + ) + + @call_break_graph_decorator(push_n=1) + def BUILD_SLICE(self, instr: Instruction): + if instr.arg == 3: + step = self.stack.pop() + else: + step = ConstantVariable.wrap_literal(None, self._graph) + stop = self.stack.pop() + start = self.stack.pop() + + self.stack.push( + SliceVariable( + slice(start, stop, step), + graph=self._graph, + tracker=DummyTracker([start, stop, step]), + ) + ) + + def build_map( + self, keys: list[VariableBase], values: list[VariableBase] + ) -> VariableBase: + built_map = {} + for key, value in zip(keys, values): + assert isinstance(key, VariableBase) + # Add key to global guarded variable to avoid missing the key guard + self._graph.add_global_guarded_variable(key) + key = key.get_py_value() + built_map[key] = value + return DictVariable( + built_map, + graph=self._graph, + tracker=DummyTracker(keys + values), + ) + + def BUILD_MAP(self, instr: Instruction): + map_size = instr.arg + assert map_size * 2 <= len( + self.stack + ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." + val_for_dict = self.stack.pop_n(map_size * 2) + keys = val_for_dict[::2] + values = val_for_dict[1::2] + self.stack.push(self.build_map(keys, values)) + + def BUILD_CONST_KEY_MAP(self, instr: Instruction): + map_size = instr.arg + assert map_size + 1 <= len( + self.stack + ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." + keys = self.stack.pop().get_items() + assert len(keys) == map_size + values = self.stack.pop_n(map_size) + self.stack.push(self.build_map(keys, values)) + + def build_seq_unpack(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = [] + for item in unpack_values: + assert isinstance(item, (TupleVariable, ListVariable)) + retval.extend(item.get_wrapped_items()) + + if instr.opname in { + "BUILD_TUPLE_UNPACK_WITH_CALL", + "BUILD_TUPLE_UNPACK", + }: + retval = tuple(retval) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def BUILD_TUPLE_UNPACK_WITH_CALL(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_TUPLE_UNPACK(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_LIST_UNPACK(self, instr: Instruction): + self.build_seq_unpack(instr) + + def BUILD_MAP_UNPACK(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = {} + for item in unpack_values: + assert item.get_py_type() is dict + retval.update(item.get_wrapped_items()) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): + oparg = instr.arg + assert isinstance(oparg, int) + unpack_values = self.stack.pop_n(oparg) + + retval = {} + for item in unpack_values: + assert item.get_py_type() is dict + wrapped_item = item.get_wrapped_items() + if wrapped_item.items() & retval.items(): + raise InnerError( + "BUILD_MAP_UNPACK_WITH_CALL found repeated key." + ) + retval.update(wrapped_item) + + self.stack.push( + VariableFactory.from_value( + retval, self._graph, DummyTracker(unpack_values) + ) + ) + + def PRECALL(self, instr: Instruction): + assert isinstance(instr.arg, int) + is_method_layout = not isinstance( + self.stack.peek[instr.arg + 2], NullVariable + ) + nargs = instr.arg + int(is_method_layout) + method = self.stack.peek[nargs + 1] + if not is_method_layout and isinstance(method, MethodVariable): + unbound_method = method.fn + self_var = method.bound_instance + self.stack.peek[nargs + 1] = self_var + self.stack.peek[nargs + 2] = unbound_method + + def KW_NAMES(self, instr: Instruction): + assert self._call_shape is None + assert isinstance(instr.arg, int) + self._call_shape = self._co_consts[instr.arg].get_py_value() + + @call_break_graph_decorator(push_n=1) + def CALL(self, instr: Instruction): + assert isinstance(instr.arg, int) + assert instr.arg + 2 <= len(self.stack) + is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable) + total_args = instr.arg + int(is_method) + kwnames = self._call_shape if self._call_shape is not None else [] + n_kwargs = len(kwnames) + n_positional_args = total_args - n_kwargs + kwargs_list = self.stack.pop_n(n_kwargs) + kwargs = dict(zip(kwnames, kwargs_list)) + args = self.stack.pop_n(n_positional_args) + fn = self.stack.pop() + if not is_method: + # pop the NULL variable + self.stack.pop() + self.stack.push(fn(*args, **kwargs)) + self._call_shape = None + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION(self, instr: Instruction): + assert isinstance(instr.arg, int) + n_args = instr.arg + assert isinstance(n_args, int) + args = self.stack.pop_n(n_args) + kwargs = {} + fn = self.stack.pop() + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION_KW(self, instr: Instruction): + n_args = instr.arg + assert n_args + 2 <= len(self.stack) + + kwargs_keys = self.stack.pop() + assert isinstance(kwargs_keys, TupleVariable) + assert len(kwargs_keys) > 0 + kwargs_keys = [ + x.get_py_value() if isinstance(x, VariableBase) else x + for x in kwargs_keys.get_py_value() + ] + + # split arg_list to args and kwargs + arg_list = self.stack.pop_n(n_args) + args = arg_list[: -len(kwargs_keys)] + kwargs_values = arg_list[-len(kwargs_keys) :] + kwargs = dict(zip(kwargs_keys, kwargs_values)) + + fn = self.stack.pop() + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_FUNCTION_EX(self, instr: Instruction): + flag = instr.arg + if flag & CFE.CFE_HAS_KWARGS: + kwargs_variable = self.stack.pop() + assert isinstance(kwargs_variable, DictVariable) + kwargs = kwargs_variable.get_wrapped_items() + else: + kwargs = {} + + args_variable = self.stack.pop() + assert isinstance(args_variable, (TupleVariable, ListVariable)) + args = args_variable.get_wrapped_items() + + fn = self.stack.pop() + if sys.version_info >= (3, 11): + null = self.stack.pop() + assert isinstance(null, NullVariable) + ret = fn(*args, **kwargs) + self.stack.push(ret) + + @call_break_graph_decorator(push_n=1) + def CALL_METHOD(self, instr: Instruction): + n_args = instr.arg + assert isinstance(n_args, int) + args = self.stack.pop_n(n_args) + self_var = self.stack.pop() + method = self.stack.pop() + if isinstance(method, NullVariable): + method = self_var + else: + args = [self_var] + args + self.stack.push(method(*args)) + + @call_break_graph_decorator( + push_n=1 + ) # call instance, in, not in may call TensorVariable.get_py_value, which raise BreakGraphError + def COMPARE_OP(self, instr: Instruction): + op = dis.cmp_op[instr.arg] + right, left = self.stack.pop(), self.stack.pop() + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + @call_break_graph_decorator(push_n=1) + def IS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.arg == 0 or instr.arg == 1 + right, left = self.stack.pop(), self.stack.pop() + op = "is" if instr.arg == 0 else "is not" + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + def MAKE_FUNCTION(self, instr: Instruction): + if sys.version_info < (3, 11): + fn_name = self.stack.pop() + codeobj = self.stack.pop() + if sys.version_info >= (3, 11): + # MAKE_FUNCTION behavior actually changed in 3.11, see + # https://github.com/python/cpython/pull/93189/ + assert hasattr(codeobj.value, "co_qualname") + fn_name = ConstantVariable( + codeobj.value.co_qualname, self._graph, DummyTracker([codeobj]) + ) + + global_dict = self._globals.get_value() + + related_list = [fn_name, codeobj] + + flag = instr.arg + if flag & MF.MF_HAS_CLOSURE: + # closure should be a tuple of Variables + closure_variable = self.stack.pop() + assert isinstance(closure_variable, TupleVariable) + closure = [] + for item in closure_variable.get_wrapped_items(): + closure.append(types.CellType()) + closure[-1].cell_contents = item + closure = tuple(closure) + else: + closure = () + + if flag & MF.MF_HAS_ANNOTATION: + # can not set annotation in python env, skip it + related_list.append(self.stack.pop()) + + if flag & MF.MF_HAS_KWDEFAULTS: + raise FallbackError( + "Found need func_kwdefaults when MAKE_FUNCTION." + ) + + if flag & MF.MF_HAS_DEFAULTS: + ''' + default_args should have tracker too, like: + + def f(x): + def g(z=x): + pass + ''' + default_args_variable = self.stack.pop() + assert isinstance(default_args_variable, TupleVariable) + related_list.append(default_args_variable) + default_args = tuple(default_args_variable.get_wrapped_items()) + else: + default_args = () + + new_fn = types.FunctionType( + codeobj.get_py_value(), + global_dict, + fn_name.get_py_value(), + default_args, + closure, + ) + self.stack.push( + UserDefinedFunctionVariable( + new_fn, self._graph, DummyTracker(related_list) + ) + ) + + def GET_ITER(self, instr: Instruction): + source_obj = self.stack.pop() + iter_variable = BuiltinVariable(iter, self._graph, DanglingTracker())( + source_obj + ) + self.stack.push(iter_variable) + + def JUMP_ABSOLUTE(self, instr: Instruction): + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + + def JUMP_FORWARD(self, instr: Instruction): + self.JUMP_ABSOLUTE(instr) + + def JUMP_BACKWARD(self, instr: Instruction): + # TODO: check interrupt + self.JUMP_ABSOLUTE(instr) + + def JUMP_BACKWARD_NO_INTERRUPT(self, instr: Instruction): + self.JUMP_ABSOLUTE(instr) + + @call_break_graph_decorator(push_n=1) + def CONTAINS_OP(self, instr: Instruction): + # It will only be 0 or 1 + assert instr.arg == 0 or instr.arg == 1 + right, left = self.stack.pop(), self.stack.pop() + op = "in" if instr.arg == 0 else "not in" + self.stack.push( + BuiltinVariable( + SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() + )(left, right) + ) + + @jump_break_graph_decorator + def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): + pred_obj = self.stack.top + if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): + self._graph.add_global_guarded_variable(pred_obj) + is_jump = not bool(pred_obj) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + else: + self.stack.pop() + return + raise FallbackError( + "Currently don't support predicate a non-const / non-tensor obj." + ) + + @jump_break_graph_decorator + def JUMP_IF_TRUE_OR_POP(self, instr: Instruction): + pred_obj = self.stack.top + if isinstance(pred_obj, (ConstantVariable, ContainerVariable)): + self._graph.add_global_guarded_variable(pred_obj) + is_jump = bool(pred_obj) + if is_jump: + assert instr.jump_to is not None + self.jump_to(instr.jump_to) + else: + self.stack.pop() + return + raise FallbackError( + "Currently don't support predicate a non-const / non-tensor obj." + ) + + POP_JUMP_IF_FALSE = pop_jump_if_op_wrapper([bool, operator.not_]) + POP_JUMP_FORWARD_IF_FALSE = POP_JUMP_IF_FALSE + POP_JUMP_BACKWARD_IF_FALSE = POP_JUMP_IF_FALSE + + POP_JUMP_IF_TRUE = pop_jump_if_op_wrapper([bool]) + POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_IF_TRUE + POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_IF_TRUE + + POP_JUMP_FORWARD_IF_NONE = pop_jump_if_op_wrapper([operator_is_none]) + POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_FORWARD_IF_NONE + + POP_JUMP_FORWARD_IF_NOT_NONE = pop_jump_if_op_wrapper( + [operator_is_not_none] + ) + POP_JUMP_BACKWARD_IF_NOT_NONE = POP_JUMP_FORWARD_IF_NOT_NONE + + @call_break_graph_decorator(push_n=lambda arg: arg) + def UNPACK_SEQUENCE(self, instr: Instruction): + sequence = self.stack.pop() + seq_iter = BuiltinVariable(iter, self._graph, DanglingTracker())( + sequence + ) + unpacked = [] + for _ in range(instr.arg): + unpacked.append(seq_iter.next()) + for item in reversed(unpacked): + self.stack.push(item) + + def UNPACK_EX(self, instr: Instruction): + getitem = BuiltinVariable( + operator.getitem, self._graph, DanglingTracker() + ) + assert instr.arg is not None + sequence = self.stack.pop() + if not isinstance( + sequence, (ListVariable, TupleVariable, TensorVariable) + ): + raise FallbackError(f"Unpack {sequence} is not implemented.") + + if instr.argval >= 256: + # NOTE: If the number of unpacked variables exceeds 256, python will report an error like: + # SyntaxError: too many expressions in star-unpacking assignmen, + # so if the number of unpacked variables exceeds 256, it will be treated as the following case. + # a, b, *c, d = e + front_nums = instr.arg & 0xFF + back_nums = instr.arg >> 8 + assert ( + len(sequence) >= front_nums + back_nums + ), f"Want unpack {sequence} to {front_nums + back_nums}, but {len(sequence)} is smaller than {front_nums + back_nums}." + + for i in range( + len(sequence) - 1, len(sequence) - back_nums - 1, -1 + ): + self.stack.push(getitem(sequence, i)) + + slice_var = SliceVariable( + slice(front_nums, len(sequence) - back_nums - 1), + self._graph, + DummyTracker([sequence]), + ) + else: + # a, b, c, *d = e + assert ( + len(sequence) >= instr.arg + ), f"Want unpack {sequence} to {instr.arg}, but {len(sequence)} is smaller than {instr.arg}." + + slice_obj = slice(instr.arg, None) + slice_var = SliceVariable( + slice_obj, self._graph, ConstTracker(slice_obj) + ) + front_nums = instr.arg + self.stack.push(getitem(sequence, slice_var)) + for i in range(front_nums - 1, -1, -1): + self.stack.push(getitem(sequence, i)) + + def FORMAT_VALUE(self, instr: Instruction): + flag = instr.arg + assert flag is not None + which_conversion = flag & FV.FVC_MASK + have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) + + fmt_spec = self.stack.pop().get_py_value() if have_fmt_spec else "" + value = self.stack.pop() + + if which_conversion == FV.FVC_NONE: + convert_fn = None + elif which_conversion == FV.FVC_STR: + convert_fn = "__str__" + elif which_conversion == FV.FVC_REPR: + convert_fn = "__repr__" + elif which_conversion == FV.FVC_ASCII: + convert_fn = "__ascii__" + else: + raise InnerError( + f"Unexpected conversion flag {flag} for FORMAT_VALUE" + ) + + # different type will lead to different Tracker, so call self.stack.push in different branch + if isinstance(value, ConstantVariable): + result = value.get_py_value() + if convert_fn is not None: + result = getattr(result, convert_fn)(result) + + if not isinstance(result, str) or fmt_spec != "": + result = format(result, fmt_spec) + + self.stack.push( + ConstantVariable(result, self._graph, DummyTracker([value])) + ) + else: + raise FallbackError(f"Do not support format {type(value)} now") + + # NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet + def DICT_UPDATE(self, instr: Instruction): + dict_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], dict_value + ) + + def DICT_MERGE(self, instr: Instruction): + dict_value = self.stack.pop() + assert isinstance(instr.arg, int) + for key in dict_value.get_wrapped_items().keys(): + result = ( + self.stack.peek[instr.arg].get_wrapped_items().get(key, None) + ) + if result is not None: + raise InnerError( + f"got multiple values for keyword argument '{key}'" + ) + BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], dict_value + ) + + def LIST_APPEND(self, instr: Instruction): + list_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], list_value + ) + + def MAP_ADD(self, instr: Instruction): + key, value = self.stack.pop_n(2) + assert isinstance(instr.arg, int) + BuiltinVariable(operator.setitem, self._graph, DanglingTracker())( + self.stack.peek[instr.arg], key, value + ) + + def LIST_EXTEND(self, instr: Instruction): + list_value = self.stack.pop() + assert isinstance(instr.arg, int) + BuiltinVariable(list.extend, self._graph, tracker=DanglingTracker())( + self.stack.peek[instr.arg], list_value + ) + + def LIST_TO_TUPLE(self, instr: Instruction): + list_value = self.stack.pop() + self.stack.push( + TupleVariable( + list_value.get_wrapped_items(), + self._graph, + DummyTracker([list_value]), + ) + ) + + +class OpcodeExecutor(OpcodeExecutorBase): + """ + A class that represents an executor for opcode operations. + + Args: + frame: The frame object. + + """ + + def __init__(self, frame: types.FrameType, **kwargs): + graph = FunctionGraph(frame, **kwargs) + self._frame = frame + self._name = "Executor" + self.call_stack[:] = [] + super().__init__(frame.f_code, graph) + Dispatcher.graph = graph + + def cleanup(self): + self._graph.pycode_gen = None + Dispatcher.graph = None + + @event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2) + def _prepare_virtual_env(self): + """ + Prepare the virtual environment for execution by adding variables from locals, globals, builtins, and constants. + + """ + log( + 3, + f"[Executor] code options: co_cellvars={self._frame.f_code.co_cellvars}\n", + ) + free_or_cell_vars = ( + self._frame.f_code.co_cellvars + self._frame.f_code.co_freevars + ) + for name, value in self._frame.f_locals.items(): + tracker = ( + CellTracker(name) + if name in free_or_cell_vars + else LocalTracker(name) + ) + self._locals[name] = VariableFactory.from_value( + value, self._graph, tracker, debug_name=name + ) + + for name in free_or_cell_vars: + # create a cell for each variable. + self._cells[name] = CellVariable() # put in cells. + if name in self._locals: + self._cells[name].set_value(self._locals[name]) + + self._globals = GlobalVariable( + self._frame.f_globals, + self._graph, + DanglingTracker(), + ) + + self._builtins = self._graph._builtins + + for value in self._code.co_consts: + self._co_consts.append( + VariableFactory.from_value( + value, self._graph, ConstTracker(value) + ) + ) + + def _create_resume_fn(self, index, stack_size=0): + """ + Create a resume function and its inputs at the specified index. + + Args: + index: The index at which the resume function is created. + stack_size: The size of the stack. + + Returns: + The resume function and its inputs. + + """ + pycode_gen = PyCodeGen(self._frame) + fn, inputs = pycode_gen.gen_resume_fn_at(index, stack_size) + return fn, inputs + + @fallback_when_occur_error + def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): + """ + Break the graph at a JUMP instruction. + + Args: + result: The result variable of the jump instruction. + instr: The jump instruction. + + """ + self._graph.add_global_guarded_variable(result) + stack_size = len(self.stack) + if_fn, if_inputs = self._create_resume_fn( + self.indexof(instr) + 1, stack_size + ) + else_fn, else_inputs = self._create_resume_fn( + self.indexof(instr.jump_to), stack_size + ) + + # gen call static fn opcode + inputs_name = if_inputs | else_inputs + inputs_var = [ + self.get_var(name) + for name in inputs_name + if self.get_var(name) is not result + ] + ret_vars = [ + result, + ] + inputs_var + # Collect all the to store variables. + store_vars = [] + for stack_arg in self.stack: + store_vars.append(stack_arg) + for name in inputs_name: + store_vars.append(self.get_var(name)) + + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + # only pop the input of if/else resume fn, and keep the bool tensor result on the stack + for _ in inputs_var: + self._graph.pycode_gen.gen_pop_top() + + # gen call if/else resume fn opcode + if if_fn is not None: + self._graph.pycode_gen.gen_load_object( + if_fn, if_fn.__code__.co_name + ) + insert_index = len(self._graph.pycode_gen._instructions) - 1 + for i, stack_arg in enumerate(self.stack): + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - 1 + ) + for name in if_inputs: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_call_function( + argc=if_fn.__code__.co_argcount, + ) + self._graph.pycode_gen.gen_return() + else: + insert_index = len(self._graph.pycode_gen._instructions) - 1 + self._graph.pycode_gen.gen_return() + + if else_fn is not None: + self._graph.pycode_gen.gen_load_object( + else_fn, else_fn.__code__.co_name + ) + jump_to = self._graph.pycode_gen._instructions[-1] + for i, stack_arg in enumerate(self.stack): + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - 1 + ) + for name in else_inputs: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_call_function( + argc=else_fn.__code__.co_argcount, + ) + self._graph.pycode_gen.gen_return() + else: + self._graph.pycode_gen.gen_return() + jump_to = self._graph.pycode_gen._instructions[-1] + + # gen jump opcode + self._graph.pycode_gen._insert_instr( + insert_index, instr.opname, jump_to=jump_to + ) + + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + + @fallback_when_occur_error + def _break_graph_in_call( + self, + origin_stack: VariableStack, + instr: Instruction, + push_n: int | Callable[[int | None], int], + ): + """ + Break the graph at a CALL instruction. + + Args: + origin_stack: The original stack. + instr: The call instruction. + push_n: The number of elements to be pushed onto the stack. + + """ + push_n = push_n(instr.arg) if callable(push_n) else push_n + index = self.indexof(instr) + self.stack = origin_stack + + # gen call static fn opcode + ret_vars = [ + arg + for arg in self.stack + if isinstance(arg, (TensorVariable, ContainerVariable)) + ] + resume_input_name = analysis_inputs(self._instructions, index + 1) + ret_vars = ret_vars + [ + self.get_var(name) + for name in resume_input_name + if self.get_var(name) not in ret_vars + ] + + # Collect all the to store variables. + store_vars = [] + for stack_arg in self.stack: + store_vars.append(stack_arg) + for name in resume_input_name: + store_vars.append(self.get_var(name)) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + + for _ in ret_vars: + self._graph.pycode_gen.gen_pop_top() + + # gen graph break call fn opcode + stack_effect = calc_stack_effect(instr) + pop_n = push_n - stack_effect + + for i, stack_arg in enumerate(self.stack): + var_loader.load( + stack_arg, allow_push_null=i >= len(self.stack) - pop_n + ) + + # gen call resume fn opcode + # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None. + self._graph.pycode_gen.gen_kw_names(self._call_shape) + self._graph.pycode_gen.add_pure_instructions([instr]) + self.stack.pop_n(pop_n) + stack_size = len(self.stack) + push_n + + resume_fn, _ = self._create_resume_fn(index + 1, stack_size) + if resume_fn: + self._graph.pycode_gen.gen_load_object( + resume_fn, resume_fn.__code__.co_name + ) + # NOTE(zrr1999): We need to shift the resume_fn under its arguments. + # In Python 3.11+, NULL + resume_fn should be shifted together. + shift_n = 2 if sys.version_info >= (3, 11) else 1 + self._graph.pycode_gen.gen_shift_n(shift_n, stack_size + shift_n) + for name in resume_input_name: + var_loader.load(self.get_var(name)) + self._graph.pycode_gen.gen_call_function( + argc=resume_fn.__code__.co_argcount, + ) + + # gen RETURN_VALUE + self._graph.pycode_gen.gen_return() + + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + + def transform(self): + self.run() + if self.new_code is None: + raise InnerError("OpExecutor return a empty new_code.") + # stopped by RETURN_VALUE and has sir len is enough => disable_eval_frame + simulate_complete = bool(self.stop_state == "Return") + if simulate_complete: + if self._graph.sir_ctx.TOS.graph_size() < min_graph_size(): + raise FallbackError( + "Fallback after simulate for reasons.", + disable_eval_frame=True, + ) + else: + # if simulate stop with graph successfully, the all codes will be + # surrounded by the eval_frame triggers which exist in self.new_code + # we need not set disable_eval_frame=False here (for it already is) + return ( + CustomCode(self.new_code, True), + self.guard_fn, + ) + else: + # if return because breakgraph, need open eval_frame + return ( + CustomCode(self.new_code, False), + self.guard_fn, + ) + + def _gen_loop_body_between( + self, inputs: list, for_iter_idx: int, start: int, end: int + ) -> types.FunctionType: + """ + Generates the loop body between the specified indices in the instruction list. + + Args: + inputs: function inputs infos + for_iter_idx (int): For find the for_iter opcode + start (int): The start index of the loop body. + end (int): The end index of the loop body. + + Returns: + tuple: The generated loop body function object and its inputs. + + """ + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + + for_iter = origin_instrs[for_iter_idx] + + # for balance the stack (the loop body will pop iter first before break or return) + # this None is used for replace the iterator obj in stack top + pycode_gen.gen_load_const(None) + + # extend loop body main logic + pycode_gen.extend_instrs(origin_instrs[start:end]) + + # break should jump to this nop + nop_for_break = pycode_gen._add_instr("NOP") + + # need do additional operates when break + pycode_gen.gen_load_const(False) + pycode_gen.gen_store_fast(inputs[-1]) + pycode_gen.gen_load_const(None) # keep stack balance + + # continue should jump to this nop + nop_for_continue = pycode_gen._add_instr("NOP") + pycode_gen.gen_pop_top() + + # relocate jump + out_loop = for_iter.jump_to + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter: + instr.jump_to = nop_for_continue + if instr.jump_to == out_loop: + instr.jump_to = nop_for_break + + # outputs is the same as inputs + pycode_gen.gen_outputs_and_return(inputs) + return pycode_gen.create_fn_with_inputs(inputs) + + @fallback_when_occur_error + def _break_graph_in_for_loop( + self, iterator: VariableBase, for_iter: Instruction + ): + ''' + for_iter: the FOR_ITER opcode + + need find out opcodes which unpack value from FOR_ITER, by analysing stack + + case 1: + for i in iter: + + FOR_ITER + STORE_FAST i + + case 2: + for i,j in iter: + + FOR_ITER + UNPACK_SEQUENCE 2 + STORE_FAST i + STORE_FAST j + + TODO: check var is in globals or builtins, only locals considered now + ''' + # 0. prepare sub functions + # 0.1 find the range of loop body + assert for_iter.jump_to is not None + loop_body_start_idx = self.indexof(for_iter) + 1 + loop_body_end_idx = self.indexof(for_iter.jump_to) + curent_stack = 1 + + while True: + if loop_body_start_idx >= len(self._instructions): + raise InnerError("Can not balance stack in loop body.") + cur_instr = self._instructions[loop_body_start_idx] + # do not consider jump instr + stack_effect = calc_stack_effect(cur_instr, jump=False) + curent_stack += stack_effect + loop_body_start_idx += 1 + if curent_stack == 0: + break + + # 0.2 create loop body function + all_used_vars = analysis_used_names_with_space( + self._instructions, loop_body_start_idx, loop_body_end_idx + ) + loop_body_inputs = [ + k + for k, v in all_used_vars.items() + if v in (Space.locals, Space.cells) + ] + ["_break_flag"] + + loop_body_fn = self._gen_loop_body_between( + loop_body_inputs, + self.indexof(for_iter), + loop_body_start_idx, + loop_body_end_idx, + ) + + log(3, "[Resumed Function]: break graph in loop create loop body as\n") + log_do(3, lambda: dis.dis(loop_body_fn)) + + # 0.3 create after loop part function + after_loop_fn, fn_inputs = self._create_resume_fn( + loop_body_end_idx, len(self.stack) + ) + + total_inputs = OrderedSet(list(fn_inputs) + list(loop_body_inputs[:-1])) + + # 1. part before for-loop, start compile + ret_names = [ + name + for name in total_inputs + if name in chain(self._locals, self._cells) + ] + ret_vars = [self.get_var(name) for name in ret_names] + store_vars = [ret_vars[idx] for idx in range(len(ret_names))] + store_vars.extend(iter(self.stack)) + store_vars.append(iterator.get_hold()) + var_loader = self._graph.start_compile_with_name_store( + ret_vars, store_vars + ) + + for _ in ret_vars: + self._graph.pycode_gen.gen_pop_top() + + # 2. restore vars + for idx in range(len(ret_names)): + var_loader.load(ret_vars[idx]) + self._graph.pycode_gen.gen_store(ret_names[idx], self._code) + + # 3. setup vars which is created in loop + undefined_names = set() + for name in loop_body_inputs[:-1]: + if not self.has_var(name, all_used_vars[name]): + undefined_names.add(name) + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + + # close eval_frame + # TODO: need support effective strategies + # self._graph.pycode_gen.gen_disable_eval_frame() + + # 4.1 load iterator + iterator.reconstruct(self._graph.pycode_gen) + + # 4.2 gen FOR_ITER and unpack data + self._graph.pycode_gen.extend_instrs( + self._instructions[self.indexof(for_iter) : loop_body_start_idx] + ) + + # 5. call loop body + # 5.1 load loop body + self._graph.pycode_gen.gen_load_object( + loop_body_fn, loop_body_fn.__code__.co_name + ) + + # 5.2 load loop body inputs + for name in loop_body_inputs[:-1]: + self._graph.pycode_gen.gen_load(name) + + # 5.3 load break flag + self._graph.pycode_gen.gen_load_const(True) + + # 5.4 call loop body + self._graph.pycode_gen.gen_call_function( + argc=loop_body_fn.__code__.co_argcount + ) + + # 5.5 unpack and store retval, keep break_flag in stack + self._graph.pycode_gen.gen_unpack_sequence(len(loop_body_inputs)) + + for name in loop_body_inputs[:-1]: + self._graph.pycode_gen.gen_store(name, self._code) + + # 6. add jump if break + jump_if_break = self._graph.pycode_gen.gen_pop_jump( + direction=JumpDirection.FORWARD, suffix=PopJumpCond.FALSE + ) + + # 7. jump back to FOR_ITER + self._graph.pycode_gen.gen_jump( + for_iter, direction=JumpDirection.BACKWARD + ) + nop = self._graph.pycode_gen._add_instr("NOP") + for_iter.jump_to = nop + jump_if_break.jump_to = nop + + # open eval_frame + # TODO: need support effective strategies + # self._graph.pycode_gen.gen_enable_eval_frame() + + # 8. call after_loop_fn + self._graph.pycode_gen.gen_load_object( + after_loop_fn, after_loop_fn.__code__.co_name + ) + + for stack_arg in self.stack: + var_loader.load(stack_arg) + for name in fn_inputs: + if not self.has_var(name) and name not in undefined_names: + undefined_names.add(name) + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) + self._graph.pycode_gen.gen_load(name) + + self._graph.pycode_gen.gen_call_function( + argc=after_loop_fn.__code__.co_argcount + ) + + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + + def _inline_call_for_loop( + self, iterator: VariableBase, for_iter: Instruction + ): + assert for_iter.jump_to is not None + pycode_gen = PyCodeGen(self._frame) + origin_instrs = get_instructions(pycode_gen._origin_code) + + start_idx = self.indexof(for_iter) + end_idx = self.indexof(for_iter.jump_to) + + all_used_vars = analysis_used_names_with_space( + origin_instrs, start_idx, end_idx + ) + + inputs = [ + k + for k, v in all_used_vars.items() + if v in (Space.locals, Space.cells) + ] + [iterator.id] + + # 1. load iter + pycode_gen.gen_load_fast(iterator.id) + + # 2. copy main logic + pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx]) + + # 3. add break, continue marker and relocate jump + for_iter_instr = origin_instrs[start_idx] + assert for_iter_instr.jump_to is not None + out_loop_instr = for_iter_instr.jump_to + + pycode_gen.gen_jump(out_loop_instr, direction=JumpDirection.FORWARD) + nop_for_continue = pycode_gen._add_instr("NOP") + + jump = pycode_gen.gen_jump( + for_iter_instr, direction=JumpDirection.BACKWARD + ) + + nop_for_break = pycode_gen._add_instr("NOP") + + for instr in pycode_gen._instructions: + if instr.jump_to == for_iter_instr: + instr.jump_to = nop_for_continue + + if ( + instr.jump_to in origin_instrs + and origin_instrs.index(instr.jump_to) >= end_idx + ): + instr.jump_to = nop_for_break + + jump.jump_to = for_iter_instr + pycode_gen.gen_outputs_and_return(inputs) + inline_call_fn = pycode_gen.create_fn_with_inputs(inputs) + + log( + 3, + f"[Resumed Function]: Inline call for loop function {inline_call_fn.__code__.co_name}\n", + ) + log_do(3, lambda: dis.dis(inline_call_fn)) + + # TODO: update globals builtins + fn = UserDefinedFunctionVariable( + inline_call_fn, + self._graph, + DanglingTracker(), + ) + + input_vars = [ + self.get_var(name) + if self.has_var(name, all_used_vars[name]) + else SotUndefinedVar() + for name in inputs[:-1] + ] + [iterator] + ret = fn(*input_vars) + # slice_variable is [:-1] + slice_const = slice(None, -1, None) + slice_variable = SliceVariable( + slice_const, self._graph, ConstTracker(slice_const) + ) + for name, val in zip(inputs[:-1], ret[slice_variable]): + self._locals[name] = val + + def FOR_ITER(self, instr): + iterator = self.stack.pop() + backup_iter_idx = None + + start = self.indexof(instr) + end = self.indexof(instr.jump_to) + for i in range(start, end): + if self._instructions[i].opname == "RETURN_VALUE": + raise FallbackError("Found RETURN_VALUE in for loop body.") + + self._graph.add_global_guarded_variable(iterator) + + try: + if not isinstance(iterator, SequenceIterVariable): + raise BreakGraphError() + + backup_iter_idx = iterator.idx + + self._inline_call_for_loop(iterator, instr) + self._lasti = self.indexof(instr.jump_to) + except BreakGraphError as e: + log(3, f"{e}") + if backup_iter_idx: + iterator.idx = backup_iter_idx + self._graph.remove_global_guarded_variable(iterator) + self._break_graph_in_for_loop(iterator, instr) + return Stop(state="BreakGraph") + + def RETURN_VALUE(self, instr: Instruction): + assert ( + len(self.stack) == 1 + ), f"Stack must have one element, but get {len(self.stack)} elements." + ret_val = self.stack.pop() + self._graph.start_compile(ret_val) + self._graph.pycode_gen.gen_return() + self.new_code = self._graph.pycode_gen.gen_pycode() + self.guard_fn = self._graph.guard_fn + return Stop(state="Return") diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py new file mode 100644 index 00000000000000..c24e94b07ffb26 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -0,0 +1,330 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import inspect +import re +from typing import TYPE_CHECKING + +from ...profiler import event_register +from ...utils import BreakGraphError, log +from ..instruction_utils import Instruction +from .guard import StringifyExpression, union_free_vars +from .opcode_executor import OpcodeExecutorBase, Stop +from .tracker import ConstTracker, DanglingTracker, DummyTracker, Tracker +from .variables import ( + CellVariable, + FunctionGlobalVariable, + IterVariable, + SequenceIterVariable, + VariableBase, +) + +if TYPE_CHECKING: + from .pycode_generator import PyCodeGen + from .variables import FunctionVariable + + +class FunctionGlobalTracker(Tracker): + """ + A tracker class that represents a function global variable. + + Args: + fn: FunctionVariable object. + name: The name of the global variable. + + """ + + def __init__(self, fn: FunctionVariable, name: str): + super().__init__([fn]) + self.fn = fn + self.name = name + + def gen_instructions(self, codegen: PyCodeGen): + """ + Generate bytecode instructions in order to put the variables at the top of the stack. + + Args: + codegen: The PyCodeGen object used to generate bytecode. + + """ + self.fn.tracker.gen_instructions(codegen) + codegen.gen_load_attr("__globals__") + codegen.gen_load_const(self.name) + codegen.gen_subscribe() + + def trace_value_from_frame(self) -> StringifyExpression: + """ + Trace the value of the function global variable from the frame. + + Returns: + StringifyExpression: The traced value of the function global variable. + + """ + fn_tracer = self.fn.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}.__globals__['{self.name}']", + [fn_tracer], + union_free_vars(fn_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"FunctionGlobalTracker(fn={self.fn}, name={self.name})" + + +class FunctionClosureTracker(Tracker): + """ + A tracker class that represents a function closure variable. + + Args: + fn: The FunctionVariable object. + idx: The index of the closure variable. + + """ + + def __init__(self, fn: FunctionVariable, idx: int): + super().__init__([fn]) + self.fn = fn + self.idx = idx + + def gen_instructions(self, codegen: PyCodeGen): + """ + Generate bytecode instructions to trace the value of the function closure variable. + + Args: + codegen: The PyCodeGen object used to generate bytecode. + + """ + self.fn.tracker.gen_instructions(codegen) + codegen.gen_load_attr("__closure__") + codegen.gen_load_const(self.idx) + codegen.gen_subscribe() + codegen.gen_load_attr("cell_contents") + + def trace_value_from_frame(self): + """ + Trace the value of the function closure variable from the frame. + + Returns: + The traced value of the function closure variable. + + """ + fn_tracer = self.fn.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}.__closure__[{self.idx}].cell_contents", + [fn_tracer], + union_free_vars(fn_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"FunctionClosureTracker(fn={self.fn}, idx={self.idx})" + + +@contextlib.contextmanager +def signature_clear_guard(fn, name): + if not hasattr(fn, name): + yield + else: + saved_attr = getattr(fn, name) + delattr(fn, name) + yield + setattr(fn, name, saved_attr) + + +class OpcodeInlineExecutor(OpcodeExecutorBase): + """ + A class that represents an executor for inlined opcode operations. + + Args: + fn_variable: The function variable. + + """ + + def __init__( + self, + fn_variable: FunctionVariable, + *args, + **kwargs, + ): + self._fn_var = fn_variable + self.return_value: VariableBase | None = None + self._fn_value = fn_variable.value + super().__init__(fn_variable.get_code(), fn_variable.graph) + self._name = "Inline" + self._prepare_locals(*args, **kwargs) + self._prepare_closure() + + def _handle_comps(self): + is_comp = any( + x in self._fn_value.__name__ + for x in ['<listcomp>', '<dictcomp>', '<genexpr>'] + ) + if not is_comp: + return + pattern = r'implicit\d+' + for name in list(self._locals.keys()): + if re.match(pattern, name): + self._locals[name.replace('implicit', '.')] = self._locals[name] + + def _prepare_locals(self, *args, **kwargs): + """ + Prepare local variables for execution by adding them to the locals dictionary. + + """ + from .variables import VariableBase, VariableFactory + + # temparay clear the fn.__signature__ to avoid signature check error + with signature_clear_guard( + self._fn_value, "__signature__" + ), signature_clear_guard(self._fn_value, "__wrapped__"): + sig = inspect.signature(self._fn_value) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + for name, value in bound_args.arguments.items(): + assert name in sig.parameters + # Convert varargs and kwargs to Variable + if sig.parameters[name].kind == inspect.Parameter.VAR_POSITIONAL: + tracker = DummyTracker(value) + elif sig.parameters[name].kind == inspect.Parameter.VAR_KEYWORD: + tracker = DummyTracker(list(value.values())) + # Convert default args to Variable + elif not isinstance(value, VariableBase): + tracker = ConstTracker(value) + else: + tracker = value.tracker + value = VariableFactory.from_value(value, self._graph, tracker) + self._locals[name] = value + + self._handle_comps() + + log( + 5, f"[INLINE CALL] {self._code.co_name} with locals: ", self._locals + ) + + def _prepare_closure(self): + """ + Prepare closure variables for execution by adding them to the closure list. + + """ + from .variables import VariableFactory + + closure = self._fn_var.get_py_value().__closure__ + for name in self._code.co_cellvars + self._code.co_freevars: + # create a cell for each variable. + self._cells[name] = CellVariable() # put in cells. + if name in self._locals: + self._cells[name].set_value(self._locals[name]) + + if closure is None: + return + assert len(closure) == len(self._code.co_freevars) + for idx, (name, cell) in enumerate( + zip(self._code.co_freevars, closure) + ): + value = cell.cell_contents + value = VariableFactory.from_value( + value, self._graph, FunctionClosureTracker(self._fn_var, idx) + ) + # wrapped by a CellVariable + if not isinstance(value, CellVariable): + value = CellVariable(value) + self._cells[name] = value + + @event_register("OpcodeInlineExecutor: _prepare_virtual_env", event_level=2) + def _prepare_virtual_env(self): + """ + Prepare the virtual environment for execution by adding variables from globals, builtins, and constants. + + """ + from .variables import VariableFactory + + self._globals = FunctionGlobalVariable( + self._fn_var, + self._fn_value.__globals__, + self._graph, + DanglingTracker(), + ) + + self._builtins = self._graph._builtins + + # prepare consts + for value in self._code.co_consts: + self._co_consts.append( + VariableFactory.from_value( + value, self._graph, ConstTracker(value) + ) + ) + + def inline_call(self) -> VariableBase: + """ + Execute the inline call of the function. + """ + self.run() + assert self.return_value is not None + return self.return_value + + def RETURN_VALUE(self, instr: Instruction): + assert ( + len(self.stack) == 1 + ), f"Stack must have one element, but get {len(self.stack)} elements." + self.return_value = self.stack.pop() + return Stop(state="Return") + + def _break_graph_in_jump(self, result, instr: Instruction): + """ + Helper method to raise a BreakGraphError when breaking the graph in a jump operation. + + Args: + result: The result of the operation. + instr (Instruction): The jump instruction. + """ + raise BreakGraphError( + "OpcodeInlineExecutor want call _break_graph_in_jump." + ) + + def _create_resume_fn(self, index: int, stack_size: int = 0): + """ + Helper method to create a resume function for the executor. + + Args: + index (int): The index of the instruction to resume execution from. + stack_size (int, optional): The size of the stack. Defaults to 0. + """ + raise BreakGraphError("_create_resume_fn.") + + def FOR_ITER(self, instr: Instruction): + iterator = self.stack.top + assert isinstance(iterator, IterVariable) + + self._graph.add_global_guarded_variable(iterator) + + # simplely get next + if isinstance( + iterator, + SequenceIterVariable, + ): + try: + self.stack.push(iterator.next()) + except StopIteration: + self.stack.pop() + assert isinstance(instr.jump_to, Instruction) + self._lasti = self.indexof(instr.jump_to) + + else: + self._graph.remove_global_guarded_variable(iterator) + raise BreakGraphError( + f"Found {iterator.__class__.__name__} as iterator." + ) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py new file mode 100644 index 00000000000000..3e2032dcc3a800 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -0,0 +1,1072 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This class is used for abstract code generation: +# We only need to care about what type of bytecode our code needs to generate, +# without worrying about the subscripts of bytecode instructions in the code option. + +from __future__ import annotations + +import random +import sys +import types +from functools import cached_property +from typing import TYPE_CHECKING + +import opcode + +import paddle + +from ...utils import ( + FallbackError, + InnerError, + OrderedSet, + ResumeFnNameFactory, + is_clean_code, + list_contain_by_id, + list_find_index_by_id, + no_eval_frame, +) +from ..instruction_utils import ( + analysis_inputs, + calc_stack_effect, + gen_instr, + get_instructions, + instrs_info, + modify_instrs, + modify_vars, +) +from ..instruction_utils.opcode_info import ( + PYOPCODE_CACHE_SIZE, + UNCONDITIONAL_JUMP, + JumpDirection, + PopJumpCond, +) +from .instr_flag import CALL_FUNCTION_EX_FLAG + +CODE_NAME_RNG = random.Random(2023) + +if TYPE_CHECKING: + from typing import Any + + from ..instruction_utils import Instruction + + +def get_pycode_attributes() -> list[str]: + """ + Returns a list of attribute names for PyCodeObject. + NOTE(SigureMo): The order should consistent with signature specified in code_doc + 3.8: https://github.com/python/cpython/blob/3.8/Objects/codeobject.c#L416-L421 + 3.10: https://github.com/python/cpython/blob/3.10/Objects/codeobject.c#L523-L543 + 3.11: https://github.com/python/cpython/blob/3.11/Objects/codeobject.c#L1494-L1516 + + Returns: + list[str]: The attribute names for PyCodeObject. + """ + pycode_attributes = [ + "co_argcount", + "co_posonlyargcount", + "co_kwonlyargcount", + "co_nlocals", + "co_stacksize", + "co_flags", + "co_code", + "co_consts", + "co_names", + "co_varnames", + "co_filename", + "co_name", + ] + if sys.version_info >= (3, 11): + pycode_attributes.append("co_qualname") + pycode_attributes.append("co_firstlineno") + if sys.version_info >= (3, 10): + pycode_attributes.append("co_linetable") + else: + pycode_attributes.append("co_lnotab") + if sys.version_info >= (3, 11): + pycode_attributes.append("co_exceptiontable") + pycode_attributes += [ + "co_freevars", + "co_cellvars", + ] + return pycode_attributes + + +PYCODE_ATTRIBUTES = get_pycode_attributes() + + +def gen_code_options(code: types.CodeType) -> dict[str, Any]: + """ + Generates a dictionary of code options for the given code object. + + Args: + code (types.CodeType): The code object. + + Returns: + dict[str, any]: The code options. + """ + code_options = {} + for k in PYCODE_ATTRIBUTES: + val = getattr(code, k) + if isinstance(val, tuple): + val = list(val) + code_options[k] = val + + return code_options + + +def gen_new_opcode( + instrs: list[Instruction], code_options: dict[str, Any], keys: list[str] +) -> types.CodeType: + """ + Generates a new code object with the given instructions, code options, and keys. + + Args: + instrs (list[Instruction]): The instructions for the new code object. + code_options (dict[str, any]): The code options for the new code object. + keys (list[str]): The keys to specify the order of code options. + + Returns: + types.CodeType: The new code object. + """ + bytecode, linetable = assemble(instrs, code_options["co_firstlineno"]) + if sys.version_info >= (3, 10): + # Python deprecated co_lnotab in 3.10, use co_linetable instead + # https://peps.python.org/pep-0626/ + code_options["co_linetable"] = linetable + else: + code_options["co_lnotab"] = linetable + code_options["co_code"] = bytecode + code_options["co_nlocals"] = len(code_options["co_varnames"]) + code_options["co_stacksize"] = stacksize(instrs) + if sys.version_info >= (3, 11): + # TODO: generate 3.11 exception table + code_options["co_exceptiontable"] = bytes([]) + for key, val in code_options.items(): + if isinstance(val, list): + code_options[key] = tuple(val) + # code_options is a dict, use keys to makesure the input order + return types.CodeType(*[code_options[k] for k in keys]) + + +def assemble( + instructions: list[Instruction], firstlineno: int +) -> tuple[bytes, bytes]: + """ + Assembles a list of instructions into bytecode and lnotab. + + Args: + instructions (list[Instruction]): The list of instructions to assemble. + firstlineno (int): The starting line number. + + Returns: + tuple[bytes, bytes]: The assembled bytecode and lnotab. + """ + code = [] + linetable = [] + + calc_linetable, update_cursor = create_linetable_calculator(firstlineno) + + for instr in instructions: + # set linetable, Python 3.11 need to set linetable for each instruction + if instr.starts_line is not None or sys.version_info >= (3, 11): + linetable.extend(calc_linetable(instr.starts_line, len(code))) + update_cursor(instr.starts_line, len(code)) + + # get bytecode + arg = instr.arg or 0 + code.extend((instr.opcode, arg & 0xFF)) + # fill CACHE + for _ in range(get_instruction_size(instr) // 2 - 1): + code.extend((0, 0)) + + if sys.version_info >= (3, 11): + # End hook for Python 3.11 + linetable.extend(calc_linetable(None, len(code))) + elif sys.version_info >= (3, 10): + # End hook for Python 3.10 + linetable.extend(calc_linetable(0, len(code))) + + return bytes(code), bytes(linetable) + + +def to_byte(num): + """ + Converts a negative number to an unsigned byte. + + Args: + num (int): The number to convert. + + Returns: + int: The converted unsigned byte. + """ + if num < 0: + num += 256 + return num + + +def get_instruction_size(instr: Instruction) -> int: + cache_size = 0 + if sys.version_info >= (3, 11): + cache_size = PYOPCODE_CACHE_SIZE.get(instr.opname, 0) + return 2 * (cache_size + 1) + + +def create_linetable_calculator(firstlineno: int): + """ + Creates a line table calculator function. + + Args: + firstlineno (int): The starting line number. + + Returns: + Callable: The line table calculator function. + """ + cur_lineno = firstlineno + cur_bytecode = 0 + line_offset = 0 # For Python 3.10 + + def update_cursor(starts_line: int | None, code_length: int): + nonlocal cur_lineno, cur_bytecode + cur_bytecode = code_length + if starts_line is not None: + cur_lineno = starts_line + + def calc_lnotab(starts_line: int, code_length: int): + """ + Calculates the lnotab for Python 3.8 and 3.9. + https://github.com/python/cpython/blob/3.9/Objects/lnotab_notes.txt + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The lnotab. + """ + nonlocal cur_lineno, cur_bytecode + line_offset = starts_line - cur_lineno + byte_offset = code_length - cur_bytecode + result = [] + + while line_offset or byte_offset: + line_offset_step = min(max(line_offset, -128), 127) + byte_offset_step = min(max(byte_offset, 0), 255) + result.extend((byte_offset_step, to_byte(line_offset_step))) + line_offset -= line_offset_step + byte_offset -= byte_offset_step + return result + + def calc_linetable_py310(starts_line: int, code_length: int): + """ + Calculates the linetable for Python 3.10. + https://github.com/python/cpython/blob/3.10/Objects/lnotab_notes.txt + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The linetable. + """ + nonlocal cur_lineno, cur_bytecode, line_offset + byte_offset = code_length - cur_bytecode + result = [] + while line_offset or byte_offset: + line_offset_step = min(max(line_offset, -127), 127) + byte_offset_step = min(max(byte_offset, 0), 254) + result.extend((byte_offset_step, to_byte(line_offset_step))) + line_offset -= line_offset_step + byte_offset -= byte_offset_step + line_offset = starts_line - cur_lineno + return result + + def _encode_varint(num: int): + """ + Encode unsigned integer into variable-length format. + """ + continue_flag = 0b01 << 6 + stop_flag = 0b00 << 6 + while num >= 0x40: + yield (num & 0x3F) | continue_flag + num >>= 6 + yield num | stop_flag + + def _encode_svarint(num: int): + """ + Encode signed integer into variable-length format. + """ + unsigned_value = (((-num) << 1) | 1) if num < 0 else (num << 1) + yield from _encode_varint(unsigned_value) + + def _encode_bytecode_to_entries_py311(line_offset: int, byte_offset: int): + if not byte_offset: + return [] + if 0 < byte_offset <= 8: + entry_head = 0b1_1101_000 | (byte_offset - 1) + return [entry_head, *list(_encode_svarint(line_offset))] + return [ + *_encode_bytecode_to_entries_py311(line_offset, 8), + *_encode_bytecode_to_entries_py311(line_offset, byte_offset - 8), + ] + + def calc_linetable_py311(starts_line: int | None, code_length: int): + """ + Calculates the linetable for Python 3.11. + https://github.com/python/cpython/blob/3.11/Objects/locations.md + + Args: + starts_line (int): The line number where the instruction starts. + code_length (int): The length of the code. + + Returns: + list[int]: The linetable. + """ + nonlocal cur_lineno, cur_bytecode + line_offset = starts_line - cur_lineno if starts_line is not None else 0 + byte_offset = (code_length - cur_bytecode) // 2 + return _encode_bytecode_to_entries_py311(line_offset, byte_offset) + + if sys.version_info >= (3, 11): + return calc_linetable_py311, update_cursor + elif sys.version_info >= (3, 10): + return calc_linetable_py310, update_cursor + else: + return calc_lnotab, update_cursor + + +def compile_exception_table(): + """Compile the exception table, it is used for Python 3.11+. + See https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt + """ + # TODO + ... + + +def stacksize(instructions: list[Instruction]) -> float: + """ + Calculates the maximum stack size before each opcode is called. + + Args: + instructions (list[Instruction]): The list of instructions. + + Returns: + int: The maximum stack size. + """ + max_stack = [float("-inf")] * len(instructions) + + max_stack[0] = 0 + + queue = [] + queue.append(0) + + def update_stacksize(lasti: int, nexti: int, stack_effect: int): + """ + Updates the maximum stack size. + + Args: + lasti (int): The index of the last instruction. + nexti (int): The index of the next instruction. + stack_effect (int): The effect on the stack size. + + Returns: + None + """ + old_max = max_stack[nexti] + max_stack[nexti] = max( + max_stack[nexti], max_stack[lasti] + stack_effect + ) + if old_max != max_stack[nexti]: + if nexti not in queue: # may be slow, we can use a flag. + queue.append(nexti) + + while len(queue) > 0: + idx = queue[0] + del queue[0] + instr = instructions[idx] + opname = instr.opname + if ( + idx + 1 < len(instructions) + and instr.opname not in UNCONDITIONAL_JUMP + ): + stack_effect = calc_stack_effect(instr, jump=False) + update_stacksize(idx, idx + 1, stack_effect) + + if instr.opcode in opcode.hasjabs or instr.opcode in opcode.hasjrel: + stack_effect = calc_stack_effect(instr, jump=True) + target_idx = instructions.index(instr.jump_to) + update_stacksize(idx, target_idx, stack_effect) + + # assert min(min_stack) >= 0 # min_stack may be a negative number when try: except is got. + return max(max_stack) + + +class PyCodeGen: + """Helper to create new code object""" + + def __init__( + self, frame: types.FrameType, disable_eval_frame: bool = False + ): + """ + Initializes a PyCodeGen object. + + Args: + frame: The frame to be translated. + disable_eval_frame (bool): Whether to disable the evaluation frame. Defaults to False. + """ + self._frame = frame + self._origin_code = frame.f_code + self._code_options = gen_code_options(self._origin_code) + self.update_code_name("", is_resumed_fn=False) + self._f_globals = frame.f_globals + self._instructions = [] + self.disable_eval_frame = disable_eval_frame + if self.disable_eval_frame: + self.gen_disable_eval_frame() + + def insert_prefix_instructions(self): + """ + Insert prefix instructions to the instruction list. + In Python 3.11+, we need to insert MAKE_CELL and COPY_FREE_VARS before the + first instruction. + The implementation is based on cpython implementation: + https://github.com/python/cpython/blob/f45ef5edabb1cc0748f3326e7114b8aaa0424392/Python/compile.c#L8177 + """ + prefixes = [] + if sys.version_info >= (3, 11): + if self._code_options["co_cellvars"]: + # Insert MAKE_CELL + name_map = list( + OrderedSet(self._code_options["co_varnames"]) + | OrderedSet(self._code_options["co_cellvars"]) + ) + + for i in self._code_options["co_cellvars"]: + idx: int = name_map.index(i) + prefixes.append(gen_instr("MAKE_CELL", arg=idx, argval=i)) + + if self._code_options["co_freevars"]: + n_freevars = len(self._code_options["co_freevars"]) + # Insert COPY_FREE_VARS + prefixes.append( + gen_instr( + "COPY_FREE_VARS", arg=n_freevars, argval=n_freevars + ) + ) + + # Insert RESUME + prefixes.append(gen_instr("RESUME", arg=0, argval=0)) + self._instructions[:] = prefixes + self._instructions + + def update_code_name(self, fn_name, is_resumed_fn): + if is_resumed_fn: + self._code_options[ + 'co_name' + ] = f"${fn_name}@{self._code_options['co_name'][1:]}" + else: + if self._code_options['co_name'].startswith("$"): + self._code_options[ + 'co_name' + ] = f"#{self._code_options['co_name']}" + elif not self._code_options['co_name'].startswith("#"): + random_number = int(CODE_NAME_RNG.random() * 100000000) + self._code_options[ + 'co_name' + ] = f"#{self._code_options['co_name']}_{hex(random_number & 0xFFFFF)[2:]:0>5}" + + def gen_pycode(self) -> types.CodeType: + """ + Generates a new pycode that is runnable. + + Returns: + CodeType: The generated code object. + """ + self.insert_prefix_instructions() + modify_instrs(self._instructions) + modify_vars(self._instructions, self._code_options) + new_code = gen_new_opcode( + self._instructions, self._code_options, PYCODE_ATTRIBUTES + ) + return new_code + + def gen_resume_fn_at( + self, index: int, stack_size: int = 0 + ) -> tuple[None | types.FunctionType, OrderedSet[str]]: + """ + Generates a resume function at the specified index in the instruction list. + + Args: + index (int): The index in the instruction list to generate the resume function. + stack_size (int): The size of the stack. Defaults to 0. + + Returns: + tuple: The resume function object and the inputs to the function. + + """ + self._instructions = get_instructions(self._origin_code) + # TODO(dev): could give an example code here? + if self._instructions[index].opname == 'RETURN_VALUE': + return None, OrderedSet() + inputs = analysis_inputs(self._instructions, index) + fn_name = ResumeFnNameFactory().next() + stack_arg_str = fn_name + '_stack_{}' + self._instructions = ( + [ + gen_instr('LOAD_FAST', argval=stack_arg_str.format(i)) + for i in range(stack_size) + ] + + [gen_instr('JUMP_FORWARD', jump_to=self._instructions[index])] + + self._instructions + ) + + self._code_options['co_argcount'] = len(inputs) + stack_size + # inputs should be at the front of the co_varnames + self._code_options['co_varnames'] = list( + [stack_arg_str.format(i) for i in range(stack_size)] + + list(inputs) + + [ + var_name + for var_name in self._origin_code.co_varnames + if var_name not in inputs + ] + ) + + self.update_code_name(fn_name, is_resumed_fn=True) + + new_code = self.gen_pycode() + if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: + raise FallbackError("Break graph in closure is not support.") + fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) + + return fn, inputs + + @cached_property + def global_null_variable(self): + from .variables.basic import NullVariable + + return NullVariable() + + def gen_disable_eval_frame(self): + """ + Generates instructions to disable the evaluation frame. + """ + if is_clean_code(): + return + self.gen_load_object( + paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("___old_eval_frame") + + def gen_enable_eval_frame(self): + """ + Generates instructions to enable the evaluation frame. + """ + if is_clean_code(): + return + self.gen_load_object( + paddle.framework.core.set_eval_frame, "paddle_set_eval_frame_fn" + ) + self.gen_load_fast("___old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + def gen_outputs_and_return(self, outputs): + for name in outputs: + self.gen_load(name) + self.gen_build_tuple(len(outputs)) + self.gen_return() + + def create_fn_with_inputs(self, inputs: list) -> types.FunctionType: + """ + Creates a function with specific input and output variables. + + Args: + inputs (list): The input variables. + + Returns: + function: The created function object. + """ + self._code_options['co_argcount'] = len(inputs) + self._code_options['co_varnames'] = list( + list(inputs) + + [ + var_name + for var_name in self._origin_code.co_varnames + if var_name not in inputs + ] + ) + fn_name = ResumeFnNameFactory().next() + self.update_code_name(fn_name, is_resumed_fn=True) + new_code = self.gen_pycode() + if len(new_code.co_freevars) + len(new_code.co_cellvars) > 0: + raise FallbackError("Break graph in closure is not support.") + fn = types.FunctionType(new_code, self._f_globals, new_code.co_name) + return fn + + def gen_load_const(self, value: Any): + """ + Generates instructions to load a constant value. + """ + # Python `list.index` will find an item equal to query, i.e. `query == item` + # returns a value of True. Since `1 == True`, this will result in an incorrect + # index. To avoid this problem, we use id for comparison. + if not list_contain_by_id(self._code_options["co_consts"], value): + self._code_options["co_consts"].append(value) + idx = list_find_index_by_id(self._code_options["co_consts"], value) + self._add_instr("LOAD_CONST", arg=idx, argval=value) + + def gen_print_log(self, message): + """print a log""" + import paddle + + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("old_eval_frame") + self.gen_load_global("print", push_null=True) + self.gen_load_const(message) + self.gen_call_function(1) + self.gen_pop_top() + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_fast("old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + def gen_dbg_function(self, dbg_fun): + """debug bytecode helper function. + Usage like: + def dbg_func(): + import inspect + import dis + print("dbg here.") + print(locals()) + frame = inspect.currentframe().f_back + code = (inspect.currentframe().f_back.f_code) + breakpoint() + print(inspect.currentframe().f_back.f_locals['y']) + + self.pycode_gen.gen_dbg_function(dbg_func) + """ + import paddle + + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_const(None) + self.gen_call_function(1) + self.gen_store_fast("old_eval_frame") + self.gen_load_object(dbg_fun, "dbg1") + self.gen_call_function(0) + self.gen_pop_top() + self.gen_load_object( + paddle.framework.core.set_eval_frame, "dbg_set_eval_frame" + ) + self.gen_load_fast("old_eval_frame") + self.gen_call_function(1) + self.gen_pop_top() + + @property + def cell_free_storage(self): + return ( + self._code_options["co_cellvars"] + + self._code_options["co_freevars"] + ) + + def gen_load(self, name): + if name in self.cell_free_storage: + self.gen_load_deref(name) + elif name in self._code_options["co_varnames"]: + self.gen_load_fast(name) + elif name in self._code_options["co_names"]: + self.gen_load_global(name, push_null=False) + else: + raise InnerError( + f"Want gen_load, but {name} can not found in code object." + ) + + def gen_store(self, name, code): + """ + Generate the bytecode for storing a variable identified by 'name' + in the corresponding symbol table and generate the appropriate + store code based on the symbol table analysis. + + Args: + name (str): The name of the variable. + """ + if name in (code.co_freevars + code.co_cellvars): + self.gen_store_deref(name) + elif name in code.co_varnames: + self.gen_store_fast(name) + elif name in code.co_names: + self.gen_store_global(name) + else: + raise InnerError( + f"Want gen_store, but {name} can not found in code object." + ) + + def gen_load_global(self, name, push_null=False): + """ + Generate the bytecode for loading a global variable. + + Args: + name (str): The name of the global variable. + """ + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + if sys.version_info >= (3, 11): + idx <<= 1 + if push_null: + idx |= 1 + self._add_instr("LOAD_GLOBAL", arg=idx, argval=name) + + def gen_load_object(self, obj, obj_name: str, push_null: bool = True): + """ + Generate the bytecode for loading an object. + + Args: + obj (Any): The object to load. + obj_name (str): The name of the object. + """ + + if obj_name not in self._f_globals: + self._f_globals[obj_name] = obj + self.gen_load_global(obj_name, push_null=push_null) + + def gen_load_null_variable(self): + """ + Generate the bytecode for loading a null variable. + """ + null_var = self.global_null_variable + self.gen_load_object(null_var, "___null_var", push_null=False) + + def gen_load_fast(self, name): + """ + Generate the bytecode for loading a local variable. + + Args: + name (str): The name of the local variable. + """ + if name not in self._code_options["co_varnames"]: + self._code_options["co_varnames"].append(name) + idx = self._code_options["co_varnames"].index(name) + self._add_instr("LOAD_FAST", arg=idx, argval=name) + + def gen_load_deref(self, name): + if name not in self.cell_free_storage: + self._code_options["co_freevars"].append(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) + self._add_instr("LOAD_DEREF", arg=idx, argval=name) + + def gen_load_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("LOAD_ATTR", arg=idx, argval=name) + + def gen_store_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("STORE_ATTR", arg=idx, argval=name) + + def gen_delete_attr(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("DELETE_ATTR", arg=idx, argval=name) + + def gen_load_method(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("LOAD_METHOD", arg=idx, argval=name) + + def gen_delete_global(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("DELETE_GLOBAL", arg=idx, argval=name) + + def gen_import_name(self, name: str): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("IMPORT_NAME", arg=idx, argval=name) + + def gen_push_null(self): + if sys.version_info >= (3, 11): + self._add_instr("PUSH_NULL") + else: + # There is no PUSH_NULL bytecode before python3.11, so we push + # a NULL element to the stack through the following bytecode + self.gen_load_const(0) + self.gen_load_const(None) + self.gen_import_name('sys') + self.gen_store_fast('sys') + self.gen_load_fast('sys') + self.gen_load_method('getsizeof') + self.gen_pop_top() + + def gen_store_fast(self, name): + if name not in self._code_options["co_varnames"]: + self._code_options["co_varnames"].append(name) + idx = self._code_options["co_varnames"].index(name) + self._add_instr("STORE_FAST", arg=idx, argval=name) + + def gen_store_global(self, name): + if name not in self._code_options["co_names"]: + self._code_options["co_names"].append(name) + idx = self._code_options["co_names"].index(name) + self._add_instr("STORE_GLOBAL", arg=idx, argval=name) + + def gen_store_deref(self, name): + if name not in self.cell_free_storage: + self._code_options["co_freevars"].append(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) + self._add_instr("STORE_DEREF", arg=idx, argval=name) + + def gen_store_subscr(self): + self._add_instr("STORE_SUBSCR") + + def gen_subscribe(self): + self._add_instr("BINARY_SUBSCR") + + def gen_build_tuple(self, count): + self._add_instr("BUILD_TUPLE", arg=count, argval=count) + + def gen_build_list(self, count): + self._add_instr("BUILD_LIST", arg=count, argval=count) + + def gen_build_map(self, count): + self._add_instr("BUILD_MAP", arg=count, argval=count) + + def gen_build_slice(self, argc): + self._add_instr("BUILD_SLICE", arg=argc, argval=argc) + + def gen_unpack_sequence(self, count): + self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count) + + def gen_call_function(self, argc=0): + if sys.version_info >= (3, 11): + self._add_instr("PRECALL", arg=argc, argval=argc) + self._add_instr("CALL", arg=argc, argval=argc) + else: + self._add_instr("CALL_FUNCTION", arg=argc, argval=argc) + + def gen_call_function_ex(self, has_kwargs): + flag = 0 + if has_kwargs: + flag |= CALL_FUNCTION_EX_FLAG.CFE_HAS_KWARGS + self._add_instr("CALL_FUNCTION_EX", arg=flag, argval=flag) + + def gen_call_method(self, argc=0): + if sys.version_info >= (3, 11): + self._add_instr("PRECALL", arg=argc, argval=argc) + self._add_instr("CALL", arg=argc, argval=argc) + else: + self._add_instr("CALL_METHOD", arg=argc, argval=argc) + + def gen_kw_names(self, kw_names: tuple[str, ...] | None): + if kw_names is None: + return + if sys.version_info < (3, 11): + raise InnerError("gen_kw_names is not supported before python3.11") + if kw_names not in self._code_options["co_consts"]: + self._code_options["co_consts"].append(kw_names) + idx = self._code_options["co_consts"].index(kw_names) + self._add_instr("KW_NAMES", arg=idx, argval=kw_names) + + def gen_pop_top(self): + self._add_instr("POP_TOP") + + def gen_rot_n(self, n): + if n <= 1: + return + if sys.version_info >= (3, 11): + for i in range(n, 1, -1): + self._add_instr("SWAP", arg=i) + elif sys.version_info >= (3, 10): + self._add_instr("ROT_N", arg=n) + else: + if n <= 4: + self._add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2]) + else: + + def rot_n_fn(n): + vars = [f"var{i}" for i in range(n)] + rotated = reversed(vars[-1:] + vars[:-1]) + fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") + fn = no_eval_frame(fn) + fn.__name__ = f"rot_{n}_fn" + return fn + + self.gen_build_tuple(n) + self.gen_load_const(rot_n_fn(n)) + self.gen_rot_n(2) + self._add_instr("CALL_FUNCTION_EX", arg=0) + self.gen_unpack_sequence(n) + + def gen_shift_n(self, s: int, n: int): + """ + Generate the bytecode for shifting the stack. + + Args: + s (int): Steps to shift. + n (int): The number of elements to shift. + """ + if s == 0 or n <= 1: + return + + # NOTE(zrr1999): right shift s steps is equal to left shift n-s steps + if abs(s) > n // 2: + new_s = s - n if s > 0 else s + n + self.gen_shift_n(new_s, n) + return + if s > 0: + # NOTE: s=1, n=3 [1,2,3,4,5] -> [1,2,5,3,4] + # s=2, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == 1: + self.gen_rot_n(n) + else: + self.gen_rot_n(n) + self.gen_shift_n(s - 1, n) + + else: # s < 0 + if sys.version_info >= (3, 11): + # NOTE: s=-1, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == -1: + for i in range(2, n + 1): + self._add_instr("SWAP", arg=i) + else: + self.gen_shift_n(-1, n) + self.gen_shift_n(s + 1, n) + else: + raise NotImplementedError( + "shift_n is not supported before python3.11" + ) + + def gen_swap(self, n): + if sys.version_info >= (3, 11): + self._add_instr("SWAP", arg=n) + else: + raise NotImplementedError("swap is not supported before python3.11") + + def gen_jump( + self, + jump_to: Instruction | None = None, + *, + direction: JumpDirection = JumpDirection.FORWARD, + ) -> Instruction: + if sys.version_info >= (3, 11): + return self._add_instr(f"JUMP_{direction.value}", jump_to=jump_to) + else: + return self._add_instr("JUMP_ABSOLUTE", jump_to=jump_to) + + def gen_pop_jump( + self, + jump_to: Instruction | None = None, + *, + direction: JumpDirection = JumpDirection.FORWARD, + suffix: PopJumpCond = PopJumpCond.NONE, + ) -> Instruction: + if sys.version_info >= (3, 11): + return self._add_instr( + f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to + ) + else: + return self._add_instr( + f"POP_JUMP_IF_{suffix.value}", jump_to=jump_to + ) + + def gen_return(self): + self._add_instr("RETURN_VALUE") + + def gen_get_iter(self): + self._add_instr("GET_ITER") + + def add_pure_instructions(self, instructions): + """ + add instructions and do nothing. + """ + self._instructions.extend(instructions) + + def _add_instr(self, *args, **kwargs): + instr = gen_instr(*args, **kwargs) + self._instructions.append(instr) + return instr + + def _insert_instr(self, index, *args, **kwargs): + instr = gen_instr(*args, **kwargs) + self._instructions.insert(index, instr) + + def pprint(self): + print('\n'.join(instrs_info(self._instructions))) + + def extend_instrs(self, instrs): + self._instructions.extend(instrs) + + def pop_instr(self): + self._instructions.pop() + + def replace_null_variable(self): + """ + Replace all NullVariables in the bytecode. + + Returns: + Optional[Tuple[Any, Callable]]: The new code object and its guard function, or None if no dummy variables are found. + """ + from .variables.basic import NullVariable + + instructions = get_instructions(self._origin_code) + has_null_variable = False + for instr in instructions: + if ( + instr.opname == 'LOAD_FAST' + and instr.argval in self._frame.f_locals.keys() + and isinstance(self._frame.f_locals[instr.argval], NullVariable) + ): + has_null_variable = True + self._frame.f_locals[instr.argval].reconstruct(self) + else: + self.add_pure_instructions([instr]) + + if has_null_variable: + new_code = self.gen_pycode() + return new_code + else: + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/side_effects.py b/python/paddle/jit/sot/opcode_translator/executor/side_effects.py new file mode 100644 index 00000000000000..f9f8fc20141a13 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/side_effects.py @@ -0,0 +1,234 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar + +from .mutable_data import MutableData +from .variables import VariableBase + +if TYPE_CHECKING: + from .mutable_data import DataGetter + from .pycode_generator import PyCodeGen + + MutableDataT = TypeVar("MutableDataT", bound=MutableData) + + +class SideEffectsState(NamedTuple): + data_id_to_proxy: dict[int, MutableData] + proxy_variables: list[VariableBase] + mutable_variables: list[VariableBase] + proxy_versions: list[int] + mutable_attrs: list[dict[str, Any]] + + +class SideEffects: + def __init__(self): + self.data_id_to_proxy: dict[int, MutableData] = {} + self.proxy_variables: list[VariableBase] = [] + self.mutable_variables: list[VariableBase] = [] + + def record_proxy_variable(self, variable: VariableBase): + if variable not in self.proxy_variables: + self.proxy_variables.append(variable) + + def record_mutable_variable(self, variable: VariableBase): + if variable not in self.mutable_variables: + self.mutable_variables.append(variable) + + def get_proxy( + self, + proxy_type: type[MutableDataT], + data: Any, + getter: DataGetter, + ) -> MutableDataT: + data_id = id(data) + if data_id not in self.data_id_to_proxy: + self.data_id_to_proxy[data_id] = proxy_type(data, getter) + return self.data_id_to_proxy[data_id] # type: ignore + + def get_state(self): + return SideEffectsState( + self.data_id_to_proxy.copy(), + self.proxy_variables.copy(), + self.mutable_variables.copy(), + [proxy.version for proxy in self.data_id_to_proxy.values()], + [ + {attr: getattr(var, attr)} + for var in self.mutable_variables + for attr in var.mutable_attrs + ], + ) + + def restore_state(self, state: SideEffectsState): + self.data_id_to_proxy = state.data_id_to_proxy + self.proxy_variables = state.proxy_variables + self.mutable_variables = state.mutable_variables + # NOTE(SigureMo): We can use the `strict=True` option in zip after + # Python 3.10. + assert len(self.data_id_to_proxy.values()) == len( + state.proxy_versions + ), "proxy_versions length not match" + assert len(self.mutable_variables) == len( + state.mutable_attrs + ), "mutable_attrs length not match" + + for proxy, version in zip( + self.data_id_to_proxy.values(), state.proxy_versions + ): + proxy.rollback(version) + + for (variable, attr), attr_dict in zip( + ( + (var, attr) + for var in self.mutable_variables + for attr in var.mutable_attrs + ), + (attr_dict for attr_dict in state.mutable_attrs), + ): + setattr(variable, attr, attr_dict[attr]) + + +class SideEffectRestorer: + def pre_gen(self, codegen: PyCodeGen): + raise NotImplementedError() + + def post_gen(self, codegen: PyCodeGen): + raise NotImplementedError() + + +class DictSideEffectRestorer(SideEffectRestorer): + """ + old_dict.clear() + old_dict.update(new_dict) + """ + + def __init__(self, var: VariableBase): + super().__init__() + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # Reference to the original dict. + # load old_dict.update and new_dict to stack. + self.var.reconstruct(codegen) + codegen.gen_load_method("update") + # Generate dict by each key-value pair. + self.var.reconstruct(codegen, use_tracker=False) + # load old_dict.clear to stack. + self.var.reconstruct(codegen) + codegen.gen_load_method("clear") + + def post_gen(self, codegen: PyCodeGen): + # Call methods to apply side effects. + codegen.gen_call_method(0) # call clear + codegen.gen_pop_top() + codegen.gen_call_method(1) # call update + codegen.gen_pop_top() + + +class ListSideEffectRestorer(SideEffectRestorer): + """ + old_list[:] = new_list + """ + + def __init__(self, var: VariableBase): + super().__init__() + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # Reference to the original list. + # load new_list to stack. + self.var.reconstruct(codegen, use_tracker=False) + # load old_list[:] to stack. + self.var.reconstruct(codegen) + codegen.gen_load_const(None) + codegen.gen_load_const(None) + codegen.gen_build_slice(2) + + def post_gen(self, codegen: PyCodeGen): + # Call STROE_SUBSCR to apply side effects. + codegen.gen_store_subscr() + + +class GlobalSetSideEffectRestorer(SideEffectRestorer): + """ + global_var = new_value + """ + + def __init__(self, name: str, var: VariableBase): + super().__init__() + self.name = name + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + self.var.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_store_global(self.name) + + +class GlobalDelSideEffectRestorer(SideEffectRestorer): + """ + del global_var + """ + + def __init__(self, name: str): + super().__init__() + self.name = name + + def pre_gen(self, codegen: PyCodeGen): + # do nothing + ... + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_delete_global(self.name) + + +class ObjSetSideEffectRestorer(SideEffectRestorer): + """ + obj.attr = new_value + """ + + def __init__(self, obj: VariableBase, name: str, var: VariableBase): + super().__init__() + self.obj = obj + self.name = name + self.var = var + + def pre_gen(self, codegen: PyCodeGen): + # value + self.var.reconstruct(codegen) + # obj + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_store_attr(self.name) + + +class ObjDelSideEffectRestorer(SideEffectRestorer): + """ + del obj.attr + """ + + def __init__(self, obj: VariableBase, name: str): + super().__init__() + self.obj = obj + self.name = name + + def pre_gen(self, codegen: PyCodeGen): + self.obj.reconstruct(codegen) + + def post_gen(self, codegen: PyCodeGen): + codegen.gen_delete_attr(self.name) diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker.py b/python/paddle/jit/sot/opcode_translator/executor/tracker.py new file mode 100644 index 00000000000000..c085e14b5b3824 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker.py @@ -0,0 +1,387 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import builtins +import sys +from typing import TYPE_CHECKING + +from ...utils import InnerError, NameGenerator +from .guard import StringifyExpression, union_free_vars + +if TYPE_CHECKING: + from typing import Sequence + + from .pycode_generator import PyCodeGen + from .variables import VariableBase + + +class Tracker: + """ + Tracker is a base class responsible for tracking variables or objects in Python code. + It is used to identify how a variable is derived from the initial state of the frame. + + Args: + inputs: The list of variables to be tracked. + + Note: + It serves as an abstract class and should not be instantiated directly. + """ + + inputs: Sequence[VariableBase] + name_generator = NameGenerator("tracker_") + + def __init__(self, inputs: Sequence[VariableBase], changed: bool = False): + self.inputs = inputs + self.changed = changed + self.id = Tracker.name_generator.next() + + def gen_instructions(self, codegen: PyCodeGen) -> None: + """ + Generate instructions based on the tracked variables. + + Args: + codegen (PyCodeGen): An instance of PyCodeGen to generate instructions. + """ + raise NotImplementedError() + + # TODO(xiongkun): trace_value_from_frame is not a good name, it should be more related to guard but not tracable. + def trace_value_from_frame(self) -> StringifyExpression: + """ + Trace the value of the tracked variables from the frame. It used for generating the guard. + + Returns: + The value of the tracked variables. + """ + raise NotImplementedError() + + def is_traceable(self) -> bool: + """ + Determine if all the tracked variables can be traced from the frame. + + Returns: + bool, True if all tracked variables are traceable, False otherwise. + """ + if self.changed: + return False + for input in self.inputs: + if not input.tracker.is_traceable(): + return False + return True + + def need_guard(self) -> bool: + return self.is_traceable() + + +class DummyTracker(Tracker): + """ + DummyTracker is a subclass of Tracker that specifically tracks variables cannot be reproduced from the frame. + It is mostly generated by complex operations (instructions). + + Args: + inputs (list[VariableBase]): The input variables associated with the generated variables. + """ + + def __init__(self, inputs: Sequence[VariableBase]): + super().__init__(inputs) + + def gen_instructions(self, codegen: PyCodeGen): + raise InnerError("DummyTracker has no instructions") + + def trace_value_from_frame(self): + raise InnerError("DummyTracker can't trace value from frame") + + def is_traceable(self): + return False + + def __repr__(self) -> str: + return f"DummyTracker(num_inputs={len(self.inputs)})" + + def need_guard(self) -> bool: + return False + + +class DanglingTracker(Tracker): + """ + DanglingTracker is a subclass of Tracker that specifically tracks variables that are not in the frame. + Variables whose tracker is DanglingTracker should not be placed on the stack, except for NullVariable. + DanglingTracker is often used in conjunction with BuiltinVariable to reuse the dispatch mechanism. + + Examples: + >>> import operator + >>> from sot.opcode_translator.executor.variables import BuiltinVariable, ConstantVariable + >>> a = ConstantVariable.wrap_literal(1, None) + >>> b = ConstantVariable.wrap_literal(2, None) + >>> c = BuiltinVariable(operator.add, None, DanglingTracker())(a, b) + >>> c.value + 3 + """ + + def __init__(self): + super().__init__([]) + + def gen_instructions(self, codegen: PyCodeGen): + raise InnerError("DanglingTracker has no instructions") + + def trace_value_from_frame(self): + raise InnerError("DanglingTracker can't trace value from frame") + + def is_traceable(self): + return False + + def __repr__(self) -> str: + return "DanglingTracker()" + + +class LocalTracker(Tracker): + """ + LocalTracker is a subclass of Tracker that specifically tracks variables from f_locals of frame. + + Args: + name (str): The name of the variable in f_locals to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_fast(self.name) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression(f"frame.f_locals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"LocalTracker(name={self.name})" + + +class CellTracker(LocalTracker): + def gen_instructions(self, codegen: PyCodeGen): + codegen.gen_load_deref(self.name) + + def trace_value_from_frame(self): + return StringifyExpression(f"frame.f_locals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"CellTracker(name={self.name})" + + +class GlobalTracker(Tracker): + """ + GlobalTracker is a subclass of Tracker that specifically tracks variables from f_globals of frame. + + Args: + name (str): The name of the variable in f_globals to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_global(self.name, push_null=False) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression(f"frame.f_globals['{self.name}']", [], {}) + + def __repr__(self) -> str: + return f"GlobalTracker(name={self.name})" + + +class BuiltinTracker(Tracker): + """ + BuiltinTracker is a subclass of Tracker that specifically tracks variables from f_builtins of frame. + + Args: + name (str): The name of the variable in f_builtins to be tracked. + """ + + def __init__(self, name: str): + super().__init__([]) + self.name = name + + def gen_instructions(self, codegen: PyCodeGen) -> None: + codegen.gen_load_global(self.name, push_null=False) + + def trace_value_from_frame(self) -> StringifyExpression: + return StringifyExpression( + f"builtins.__dict__['{self.name}']", [], {"builtins": builtins} + ) + + def __repr__(self) -> str: + return f"BuiltinTracker(name={self.name})" + + +class ConstTracker(Tracker): + """ + ConstTracker is a subclass of Tracker that specifically tracks a constant value. + + Args: + value (Any): The value of the constant. + """ + + def __init__(self, value): + super().__init__([]) + self.value = value + + def gen_instructions(self, codegen: PyCodeGen): + codegen.gen_load_const(self.value) + + def trace_value_from_frame(self): + return StringifyExpression(f"{self.value!r}", [], {}) + + def __repr__(self) -> str: + return f"ConstTracker(value={self.value})" + + def need_guard(self) -> bool: + return False + + +class GetAttrTracker(Tracker): + """ + GetAttrTracker is a subclass of Tracker that specifically tracks the attribute access of an variable. + + Args: + obj (VariableBase): The object whose attribute is to be tracked. + attr (str): The attribute to be tracked. + """ + + def __init__(self, obj: VariableBase, attr: str, changed: bool = False): + super().__init__([obj], changed) + self.obj = obj + self.attr = attr + + def gen_instructions(self, codegen: PyCodeGen): + self.obj.tracker.gen_instructions(codegen) + codegen.gen_load_attr(self.attr) + + def trace_value_from_frame(self): + obj_tracer = self.obj.tracker.trace_value_from_frame() + if self.attr.isidentifier(): + expr = f"{{}}.{self.attr}" + else: + expr = f"getattr({{}}, '{self.attr}')" + return StringifyExpression( + expr, + [obj_tracer], + union_free_vars(obj_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"GetAttrTracker(attr={self.attr})" + + def need_guard(self) -> bool: + return self.is_traceable() and self.obj.tracker.need_guard() + + +class GetItemTracker(Tracker): + """ + GetItemTracker is a subclass of Tracker that specifically tracks item access of a container variable. + + It generates instructions and traces the item value from the frame. + + Args: + container_var (VariableBase): The container object whose item is to be tracked. + key: The key/index of the item to be tracked. + """ + + def __init__(self, container_var: VariableBase, key: object, changed=False): + super().__init__([container_var], changed) + self.container = container_var + self.key = key + + def gen_instructions(self, codegen: PyCodeGen): + self.container.tracker.gen_instructions(codegen) + if isinstance(self.key, slice): + codegen.gen_load_const(self.key.start) + codegen.gen_load_const(self.key.stop) + codegen.gen_load_const(self.key.step) + codegen.gen_build_slice(3) + else: + codegen.gen_load_const(self.key) + codegen.gen_subscribe() + + def trace_value_from_frame(self): + container_tracer = self.container.tracker.trace_value_from_frame() + return StringifyExpression( + f"{{}}[{self.key!r}]", + [container_tracer], + union_free_vars(container_tracer.free_vars), + ) + + def __repr__(self) -> str: + return f"GetItemTracker(key={self.key!r})" + + def need_guard(self) -> bool: + return self.is_traceable() and self.container.tracker.need_guard() + + +class GetIterTracker(Tracker): + """ + GetIterTracker is a subclass of Tracker that specifically tracks iteration of a variable. + + It generates instructions and traces the iterator from the frame. + + Args: + iter_source (VariableBase): The source variable to be iterated. + """ + + def __init__(self, iter_source: VariableBase): + super().__init__([iter_source]) + self.iter_source = iter_source + + def gen_instructions(self, codegen: PyCodeGen): + self.iter_source.tracker.gen_instructions(codegen) + codegen._add_instr("GET_ITER") + + def trace_value_from_frame(self): + iter_source_tracer = self.iter_source.tracker.trace_value_from_frame() + return StringifyExpression( + "iter({})", + [iter_source_tracer], + union_free_vars(iter_source_tracer.free_vars), + ) + + def __repr__(self) -> str: + return "GetIterTracker()" + + +class CreateLayerTracker(Tracker): + def __init__(self, layer_class, args, kwargs): + super().__init__([layer_class] + list(args) + list(kwargs.values())) + self.layer_class = layer_class + self.args = args + self.kwargs = kwargs + + def gen_instructions(self, codegen: PyCodeGen): + if sys.version_info >= (3, 11): + codegen.gen_push_null() + + self.layer_class.reconstruct(codegen) + for variable in self.args: + variable.reconstruct(codegen) + + if len(self.kwargs) == 0: + codegen.gen_call_function(argc=len(self.args)) + else: + codegen.gen_build_tuple(len(self.args)) + for k, v in self.kwargs.items(): + codegen.gen_load_const(k) + v.reconstruct(codegen) + codegen.gen_build_map(len(self.kwargs)) + codegen.gen_call_function_ex(has_kwargs=True) + + def __repr__(self) -> str: + return f"CreateLayerTracker(Layer={self.layer_class}, args={self.args}, kwargs={self.kwargs})" diff --git a/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py b/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py new file mode 100644 index 00000000000000..f132c34abcac16 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/tracker_viewer.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import queue +from typing import TYPE_CHECKING + +from .tracker import DummyTracker +from .variables import VariableBase + +SIR_GRAPH_CLUSTER_NAME = "cluster_sir_part" + +if TYPE_CHECKING: + import graphviz + + +def try_import_graphviz(): + try: + import graphviz + + return graphviz + except ImportError: + return None + + +def draw_variable(graph: graphviz.Digraph, var: VariableBase): + """ + Draw and colour a node in the graph. + + Args: + graph (graphviz.Digraph): The graph to draw the variable. + var (VariableBase): The variable to draw. + + Returns: + None + """ + # Draw Variable + graph.attr('node', shape='oval', style="filled", fillcolor='aliceblue') + graph.attr('edge', style='solid') + graph.node(var.id, str(var)) + + # Draw Tracker + tracker = var.tracker + graph.attr('node', shape='rect', style='filled', fillcolor='beige') + if isinstance(tracker, DummyTracker): + graph.attr('edge', style='dashed') + graph.attr('node', shape='rect', style='filled', fillcolor='goldenrod') + graph.node(tracker.id, str(tracker)) + + # Draw edge (Tracker -> Variable) + graph.edge(tracker.id, var.id) + + # Draw edge (Tracker inputs -> Tracker) + graph.attr('node', shape='oval', style="filled", fillcolor='cadetblue') + graph.attr('edge', style='solid') + for input in tracker.inputs: + graph.edge(input.id, tracker.id) + + +def view_tracker( + root_variables: list[VariableBase], filename: str, format: str +): + """ + Generates a graph visualization starting from the given root variables and save it to the given file. + + Args: + root_variables (list[VariableBase]): The root variables to start the visualization from. + filename (str): The name of the file used to save the results of the visualisation. + format (str): The format (e.g., `pdf`, `png` and 'svg' etc.) of the file to save the visualization to. + + Returns: + None + """ + # TODO(SigureMo): + # 1. Colorize the trackers + # 2. Highlight the user specific node, to speedup debug process + graphviz = try_import_graphviz() + if graphviz is None: + print("Cannot import graphviz, please install it first.") + return + + graph = graphviz.Digraph("graph", filename=filename, format=format) + visited = set() + var_queue = queue.Queue() + for var in root_variables: + var_queue.put(var) + + while not var_queue.empty(): + var = var_queue.get() + if var.id in visited: + continue + visited.add(var.id) + if isinstance(var.tracker, DummyTracker): + with graph.subgraph(name=SIR_GRAPH_CLUSTER_NAME) as sir_part: + sir_part.attr(color='green') + draw_variable(sir_part, var) + else: + draw_variable(graph, var) + for input in var.tracker.inputs: + if input not in var_queue.queue: + var_queue.put(input) + + graph.render(view=False) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py new file mode 100644 index 00000000000000..9eb10fb81bcd53 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py @@ -0,0 +1,1109 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import operator +from functools import partial, reduce +from typing import TYPE_CHECKING + +import paddle + +from ...utils import BreakGraphError, FallbackError +from ...utils.magic_methods import ( + BINARY_OPS, + UNARY_OPS, + magic_method_builtin_dispatch, +) +from .dispatch_functions import ( + operator_in, + operator_is_none, + operator_is_not_none, + operator_not_in, + raise_break_graph_fn, + tensor_numel, +) +from .dispatcher import Dispatcher, optional +from .tracker import ConstTracker, DanglingTracker, DummyTracker +from .variables import ( + BuiltinVariable, + ConstantVariable, + ContainerVariable, + DictVariable, + EnumerateVariable, + ListVariable, + MapVariable, + NumpyVariable, + RangeVariable, + SliceVariable, + TupleVariable, + VariableBase, + VariableFactory, +) + +if TYPE_CHECKING: + from .variables import DataVariable, TensorVariable + + +def add_guard(var: VariableBase): + var.graph.add_global_guarded_variable(var) + return var + + +def raise_err_handle(error): + def inner(*args, **kwargs): + raise error + + return inner + + +# slice +Dispatcher.register( + slice, + ("VariableBase",), + lambda stop: SliceVariable( + slice(stop), + graph=stop.graph, + tracker=DummyTracker([stop]), + ), +) + +Dispatcher.register( + slice, + ("VariableBase", "VariableBase"), + lambda start, stop: SliceVariable( + slice(start, stop), + graph=stop.graph, + tracker=DummyTracker([start, stop]), + ), +) + +Dispatcher.register( + slice, + ("VariableBase", "VariableBase", "VariableBase"), + lambda start, stop, step: SliceVariable( + slice(start, stop, step), + graph=stop.graph, + tracker=DummyTracker([start, stop, step]), + ), +) + + +# iter +Dispatcher.register( + iter, + ("VariableBase",), + lambda variable: variable.get_iter(), +) + + +# in +Dispatcher.register( + operator_in, + ("VariableBase", "IterVariable"), + raise_err_handle(BreakGraphError("Codes like: `variable in iterator`.")), +) + +Dispatcher.register( + operator_in, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + left.id + in [ + x.id + for x in right.get_py_value(allow_tensor=True) + if hasattr(x, "id") + ], + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_in, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + left.get_py_value(allow_tensor=True) + in right.get_py_value(allow_tensor=True), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_not_in, + ("VariableBase", "IterVariable"), + raise_err_handle( + BreakGraphError("Codes like: `variable not in iterator`.") + ), +) + +Dispatcher.register( + operator_not_in, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + left.id + not in [ + x.id + for x in right.get_py_value(allow_tensor=True) + if hasattr(x, "id") + ], + left.graph, + tracker=DummyTracker([left, right]), + ), +) + +Dispatcher.register( + operator_not_in, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + left.get_py_value(allow_tensor=True) + not in right.get_py_value(allow_tensor=True), + left.graph, + tracker=DummyTracker([left, right]), + ), +) + + +# dict +Dispatcher.register( + dict, + (), + lambda: DictVariable( + {}, + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) + +Dispatcher.register( + dict, + ("DictVariable",), + lambda var: var.copy(), +) + + +@Dispatcher.register_decorator(dict) +def dispatch_dict(var: ListVariable | TupleVariable): + res_dict = {} + length_var = BuiltinVariable(len, var.graph, DanglingTracker())(var) + getitem = BuiltinVariable(operator.getitem, var.graph, DanglingTracker()) + for index in range(length_var.get_py_value()): + index_value = getitem(var, index) + # check + assert isinstance(index_value, (ListVariable, TupleVariable)) + assert len(index_value) == 2 + # recombination + key = getitem(index_value, 0) + value = getitem(index_value, 1) + value.graph.add_global_guarded_variable(key) + res_dict.update({key.get_py_value(): value}) + return DictVariable(res_dict, var.graph, DummyTracker([var])) + + +@Dispatcher.register_decorator(dict.fromkeys) +def dispatch_dict_fromkeys(seq: ListVariable | TupleVariable, default: VariableBase = None): # type: ignore + if default is None: + default = ConstantVariable.wrap_literal(None, seq.graph) + res_dict = {} + getitem = BuiltinVariable(operator.getitem, seq.graph, DanglingTracker()) + for index in range(len(seq)): + index_value = getitem(seq, index) + seq.graph.add_global_guarded_variable(index_value) + res_dict.update({index_value.get_py_value(): default}) + return DictVariable(res_dict, seq.graph, DummyTracker([seq])) + + +Dispatcher.register( + dict.get, + ("DictVariable", "ConstantVariable", optional("VariableBase")), + lambda var, key, default=None: var.get(key.get_py_value(), default), +) +Dispatcher.register( + dict.keys, + ("DictVariable",), + lambda var: var.keys(), +) + +Dispatcher.register( + operator.not_, + ("VariableBase",), + lambda x: ConstantVariable( + not x.get_py_value(allow_tensor=False), x.graph, DummyTracker([x]) + ), +) + +Dispatcher.register( + dict.values, + ("DictVariable",), + lambda var: var.values(), +) +Dispatcher.register( + dict.items, + ("DictVariable",), + lambda var: var.items(), +) +Dispatcher.register( + dict.setdefault, + ("DictVariable", "ConstantVariable", optional("VariableBase")), + lambda var, key, default=None: var.setdefault(key.get_py_value(), default), +) +Dispatcher.register( + dict.update, + ("DictVariable", "DictVariable"), + lambda var, other: var.update(other), +) +Dispatcher.register( + dict.copy, + ("DictVariable",), + lambda var: var.copy(), +) +Dispatcher.register( + dict.clear, + ("DictVariable",), + lambda var: var.clear(), +) +Dispatcher.register( + dict.pop, + ("DictVariable", "ConstantVariable"), + lambda var, key: var.pop(key.get_py_value()), +) +Dispatcher.register( + dict.pop, + ("DictVariable", "ConstantVariable", "VariableBase"), + lambda var, key, default: var.pop(key.get_py_value(), default), +) +Dispatcher.register( + dict.popitem, + ("DictVariable",), + lambda var: var.popitem(), +) + +# tuple +Dispatcher.register( + tuple, + ("ContainerVariable",), + lambda var: TupleVariable( + tuple(var.get_wrapped_items()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + tuple, + ("SequenceIterVariable",), + lambda var: TupleVariable( + tuple(var.to_list()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + tuple.count, + ("TupleVariable", "VariableBase"), + lambda var, value: var.count(value), +) +Dispatcher.register( + tuple.index, + ("TupleVariable", "VariableBase"), + lambda var, value: var.index(value), +) + +# list +Dispatcher.register( + list, + (), + lambda: ListVariable( + [], + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) + +Dispatcher.register( + list, + ("ContainerVariable",), + lambda var: ListVariable( + list(var.get_wrapped_items()), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) + +Dispatcher.register( + list, + ("IterVariable",), + lambda var: ListVariable( + var.to_list(), + graph=var.graph, + tracker=DummyTracker([var]), + ), +) +Dispatcher.register( + list.extend, + ("ListVariable", "ListVariable | TupleVariable"), + lambda var, other: var.extend(other), +) +Dispatcher.register( + list.append, + ("ListVariable", "VariableBase"), + lambda var, other: var.append(other), +) +Dispatcher.register( + list.insert, + ("ListVariable", "ConstantVariable", "VariableBase"), + lambda var, index, obj: var.insert(index.get_py_value(), obj), +) +Dispatcher.register( + list.remove, + ("ListVariable", "VariableBase"), + lambda var, other: var.remove(other), +) +Dispatcher.register( + list.pop, + ("ListVariable", optional("ConstantVariable")), + lambda var, index=None: var.pop(index), +) +Dispatcher.register( + list.clear, + ("ListVariable",), + lambda var: var.clear(), +) +Dispatcher.register( + list.sort, + ("ListVariable",), + lambda var: var.sort(), +) +Dispatcher.register( + list.reverse, + ("ListVariable",), + lambda var: var.reverse(), +) +Dispatcher.register( + list.copy, + ("ListVariable",), + lambda var: var.copy(), +) +Dispatcher.register( + list.count, + ("ListVariable", "VariableBase"), + lambda var, obj: var.count(obj), +) +Dispatcher.register( + list.index, + ("ListVariable", "VariableBase"), + lambda var, obj: var.index(obj), +) +Dispatcher.register( + operator.add, + ("ListVariable", "ListVariable"), + lambda var, other: var.concat(other), +) +Dispatcher.register( + operator.add, + ("TupleVariable", "TupleVariable"), + lambda var, other: var.concat(other), +) +Dispatcher.register( + operator.mul, + ("ListVariable | TupleVariable", "ConstantVariable"), + lambda var, other: var.repeat(other), +) + +# getattr +Dispatcher.register( + getattr, + ("VariableBase", "ConstantVariable", optional("VariableBase")), + lambda var, name, default=None: var.getattr( + add_guard(name).get_py_value(), default + ), +) + +# hasattr +Dispatcher.register( + hasattr, + ("VariableBase", "ConstantVariable"), + lambda var, name: var.hasattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + delattr, + ("VariableBase", "VariableBase"), + lambda var, name: var.delattr(add_guard(name).get_py_value()), +) + +Dispatcher.register( + setattr, + ("VariableBase", "VariableBase", "VariableBase"), + lambda var, name, value: var.setattr(add_guard(name).get_py_value(), value), +) + +# len +Dispatcher.register( + len, + ("ContainerVariable | ContainerLayerVariable",), + lambda var: var.len(), +) + +# range +# stop +Dispatcher.register( + range, + ("ConstantVariable",), + lambda stop: RangeVariable( + range(stop.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([stop]), + ), +) + +# start, stop +Dispatcher.register( + range, + ("ConstantVariable", "ConstantVariable"), + lambda start, stop: RangeVariable( + range(start.get_py_value(), stop.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([start, stop]), + ), +) +# start, stop, step +Dispatcher.register( + range, + ("ConstantVariable", "ConstantVariable", "ConstantVariable"), + lambda start, stop, step: RangeVariable( + range(start.get_py_value(), stop.get_py_value(), step.get_py_value()), + graph=stop.graph, + tracker=DummyTracker([start, stop, step]), + ), +) +# TODO(zmh): Modify +# enumerate +Dispatcher.register( + enumerate, + ("VariableBase",), + lambda var: EnumerateVariable.from_iterator( + var, graph=var.graph, tracker=DummyTracker([var]) + ), +) + + +# map +Dispatcher.register( + map, + ( + "CallableVariable", + "VariableBase", + ), + lambda fn, var: MapVariable.from_iterator( + fn, var, graph=var.graph, tracker=DummyTracker([var]) + ), +) + + +# reversed +@Dispatcher.register_decorator(reversed) +def dispatch_reversed(var: ContainerVariable): + from .tracker import DanglingTracker + from .variables import BuiltinVariable, SequenceIterVariable + + length_var = BuiltinVariable(len, var.graph, DanglingTracker())(var) + assert isinstance(length_var, ConstantVariable) + getitem = BuiltinVariable(operator.getitem, var.graph, DanglingTracker()) + out = reversed([getitem(var, i) for i in range(length_var.get_py_value())]) + out_var = ListVariable( + list(out), graph=var.graph, tracker=DummyTracker([var]) + ) + return SequenceIterVariable( + out_var, + graph=var.graph, + tracker=DummyTracker([var]), + ) + + +# isinstance +Dispatcher.register( + isinstance, + ("TensorVariable", "VariableBase"), + lambda left, right: ConstantVariable( + isinstance( + paddle.to_tensor(0), + right.get_py_value(allow_tensor=True), + ), + left.graph, + DummyTracker([left, right]), + ), +) + +Dispatcher.register( + isinstance, + ("VariableBase", "VariableBase"), + lambda left, right: ConstantVariable( + isinstance( + left.get_py_value(allow_tensor=True), + right.get_py_value(allow_tensor=True), + ), + left.graph, + DummyTracker([left, right]), + ), +) + +# bool +Dispatcher.register( + bool, + ("ContainerVariable",), + lambda var: var.bool(), +) +Dispatcher.register( + operator.truth, + ("ConstantVariable",), + lambda var: var.bool(), +) + +# str +Dispatcher.register( + str, + ("ConstantVariable",), + lambda var: var.str(), +) + + +@Dispatcher.register_decorator(str.format) +def str_format(var: ConstantVariable, *args: ConstantVariable): + return var.format(*args) + + +Dispatcher.register( + str.lower, + ("ConstantVariable",), + lambda var: var.lower(), +) + + +@Dispatcher.register_decorator(str.startswith) +def str_startswith(var: ConstantVariable, substr: ConstantVariable, beg: ConstantVariable = None, end: ConstantVariable = None): # type: ignore + value = var.get_py_value() + if end is None: + end = ConstantVariable(len(value), var.graph, DanglingTracker()) + if beg is None: + beg = ConstantVariable(0, var.graph, DanglingTracker()) + + res = value.startswith( + substr.get_py_value(), beg.get_py_value(), end.get_py_value() + ) + return ConstantVariable( + res, var.graph, DummyTracker([var, substr, beg, end]) + ) + + +@Dispatcher.register_decorator(str.endswith) +def str_endswith(var: ConstantVariable, substr: ConstantVariable, beg: ConstantVariable = None, end: ConstantVariable = None): # type: ignore + value = var.get_py_value() + if end is None: + end = ConstantVariable(len(value), var.graph, DanglingTracker()) + if beg is None: + beg = ConstantVariable(0, var.graph, DanglingTracker()) + + res = value.endswith( + substr.get_py_value(), beg.get_py_value(), end.get_py_value() + ) + return ConstantVariable( + res, var.graph, DummyTracker([var, substr, beg, end]) + ) + + +# getitem +# TODO: Should pass its Variable into the getitem and perform operations such as getting value in the getitem. like this:https://github.com/PaddlePaddle/PaddleSOT/pull/198#discussion_r1241110949 +Dispatcher.register( + operator.getitem, + ( + "TensorVariable", + "Any", + ), + lambda var, key: var.getitem( + VariableFactory.from_value( + key, graph=var.graph, tracker=ConstTracker(key) + ) + ), +) + +Dispatcher.register( + operator.getitem, + ( + "VariableBase", + "int | str", + ), + lambda var, key: var.getitem( + VariableFactory.from_value( + key, graph=var.graph, tracker=ConstTracker(key) + ) + ), +) + +Dispatcher.register( + operator.getitem, + ( + "VariableBase", + "ConstantVariable | SliceVariable", + ), + lambda var, key: var.getitem(key), +) + +# setitem +Dispatcher.register( + operator.setitem, + ( + "VariableBase", + "int | str | ConstantVariable | TensorVariable", + "int | str | ConstantVariable | TensorVariable", + ), + lambda var, key, value: var.setitem(key.get_py_value(), value), +) + +# delitem +Dispatcher.register( + operator.delitem, + ( + "VariableBase", + "int | str | TensorVariable", + ), + lambda var, key: var.delitem(key), +) +Dispatcher.register( + operator.delitem, + ( + "VariableBase", + "ConstantVariable", + ), + lambda var, key: var.delitem(key.get_py_value()), +) + + +# TensorVariable +Dispatcher.register( + paddle.is_tensor, + ("TensorVariable",), + lambda var: var.is_tensor(), +) +Dispatcher.register( + paddle.is_complex, + ("TensorVariable",), + lambda var: var.is_complex(), +) +Dispatcher.register( + paddle.is_integer, + ("TensorVariable",), + lambda var: var.is_integer(), +) +Dispatcher.register( + paddle.is_floating_point, + ("TensorVariable",), + lambda var: var.is_floating_point(), +) +Dispatcher.register( + paddle.rank, + ("TensorVariable",), + lambda var: var.ndim, +) + +Dispatcher.register( + operator.is_, + ("TensorVariable", "TensorVariable"), + lambda var, other: ConstantVariable( + var.get_symbol() == other.get_symbol(), + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +Dispatcher.register( + operator.is_, + ("TensorVariable", "VariableBase"), + lambda var, other: ConstantVariable( + False, + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +Dispatcher.register( + operator.is_, + ("VariableBase", "TensorVariable"), + lambda var, other: ConstantVariable( + False, + var.graph, + tracker=DummyTracker([var, other]), + ), +) + +# VariableBase +Dispatcher.register( + operator.is_, + ("VariableBase", "VariableBase"), + lambda var, other: ConstantVariable( + var.get_py_value() is other.get_py_value(), + var.graph, + tracker=DummyTracker([var, other]), + ), +) + + +@Dispatcher.register_decorator(operator.is_not) +def is_not_func(var: VariableBase, other: VariableBase): + handler = Dispatcher.dispatch(operator.is_, var, other) + if handler is None: + raise FallbackError( + f"Not found implementation operator.is for {var} and {other}." + ) + return handler(var, other).bool_not() + + +# is None +Dispatcher.register( + operator_is_none, + ("VariableBase",), + lambda var: BuiltinVariable(operator.is_, var.graph, DanglingTracker())( + var, ConstantVariable.wrap_literal(None, var.graph) + ), +) + +# is not None +Dispatcher.register( + operator_is_not_none, + ("VariableBase",), + lambda var: BuiltinVariable(operator.is_not, var.graph, DanglingTracker())( + var, ConstantVariable.wrap_literal(None, var.graph) + ), +) + + +# NOTE(SigureMo): Don't directly capture free var inside for-loop, use partial instead. +# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(lambda: i) +# for fn in lambdas: +# print(fn()) # result is 9, 9, 9, 9, 9, 9, 9, 9, 9, 9 +# ``` +# Rewrite by partial: +# ```python +# lambdas = [] +# for i in range(10): +# lambdas.append(partial(lambda i: i, i)) +# for fn in lambdas: +# print(fn()) # result is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 +# ``` + +# Constant +for unary_fn in UNARY_OPS: + for magic_method in magic_method_builtin_dispatch(unary_fn): + Dispatcher.register( + unary_fn, + ("ConstantVariable",), + partial( + lambda fn, var: VariableFactory.from_value( + fn(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), + unary_fn, + ), + ) +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + Dispatcher.register( + binary_fn, + ("ConstantVariable", "ConstantVariable"), + partial( + lambda fn, var, other: VariableFactory.from_value( + fn(var.get_py_value(), other.get_py_value()), + var.graph, + tracker=DummyTracker([var, other]), + ), + binary_fn, + ), + ) +# Tensor +fallback_tensor_unary_method = { + int, + bool, + operator.truth, +} + +Dispatcher.register(tensor_numel, ("TensorVariable",), lambda x: x.numel()) + +for unary_fn in UNARY_OPS: + if unary_fn in fallback_tensor_unary_method: + Dispatcher.register( + unary_fn, + ("TensorVariable",), + raise_break_graph_fn, + ) + continue + + if unary_fn is len: + Dispatcher.register( + unary_fn, + ("TensorVariable",), + lambda x: x.len(), + ) + continue + + for magic_method in magic_method_builtin_dispatch(unary_fn): + Dispatcher.register( + unary_fn, + ("TensorVariable",), + partial( + lambda magic_name, var: var.graph.call_tensor_method( + magic_name, var + ), + magic_method.name, + ), + ) +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + # skip all inplace magic method name, we will dispatch it to non-inplace + # magic methods + if magic_method.is_inplace: + continue + + if not magic_method.is_reverse: + Dispatcher.register( + binary_fn, + ( + "TensorVariable", + "TensorVariable | ConstantVariable | NumpyVariable", + ), + partial( + lambda magic_name, var, other: var.graph.call_tensor_method( + magic_name, var, other + ), + magic_method.name, + ), + ) + else: + # skip __mod__ for str and TensorVariable + if magic_method.name == "__rmod__": + + @Dispatcher.register_decorator(operator.mod) + def tensor_mod_dispatcher( + var: ConstantVariable, other: TensorVariable + ): + if var.get_py_type() is str: + raise BreakGraphError( + "(ConstantVariable % TensorVariable) raise a callback. " + ) + raise FallbackError("Tensor doesn't support __rmod__") + + else: + Dispatcher.register( + binary_fn, + ( + "ConstantVariable | NumpyVariable", + "TensorVariable", + ), + partial( + lambda reverse_magic_name, var, other: other.graph.call_tensor_method( + reverse_magic_name, other, var + ), + magic_method.name, + ), + ) + +# Register dispatch for NumpyVariable: fallback ! +for unary_fn in UNARY_OPS: + if unary_fn in [bool]: + continue + for magic_method in magic_method_builtin_dispatch(unary_fn): + + @Dispatcher.register_decorator(unary_fn) + def numpy_unary_dispatcher(var: NumpyVariable): + raise FallbackError('Numpy operator need fallback to dygraph') + + +Dispatcher.register( + operator.eq, + ("NumpyVariable", "ConstantVariable | NumpyVariable"), + lambda left, right: constant_numpy_equal(right, left), +) + + +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + + @Dispatcher.register_decorator(binary_fn) + def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable): + raise FallbackError('Numpy operator need fallback to dygraph') + + +# Register dispatch for DataVariable: directy call and return a wrapped variable. +def data_variable_binary_dispatcher(var, other, operator): + return VariableFactory.from_value( + operator(var.get_py_value(), other.get_py_value()), + var.graph, + DummyTracker([var, other]), + ) + + +for binary_fn in BINARY_OPS: + for magic_method in magic_method_builtin_dispatch(binary_fn): + Dispatcher.register( + binary_fn, + ("DataVariable", "Any"), + partial(data_variable_binary_dispatcher, operator=binary_fn), + ) + Dispatcher.register( + binary_fn, + ("Any", "DataVariable"), + partial(data_variable_binary_dispatcher, operator=binary_fn), + ) + +for unary_fn in UNARY_OPS: + for magic_method in magic_method_builtin_dispatch(unary_fn): + + def data_variable_unary_dispatcher(var: DataVariable, fn): + return VariableFactory.from_value( + fn(var.get_py_value()), + var.graph, + DummyTracker([var]), + ) + + Dispatcher.register( + unary_fn, + ("DataVariable",), + partial(data_variable_unary_dispatcher, fn=unary_fn), + ) + + +Dispatcher.register( + math.ceil, + ("ConstantVariable",), + lambda var: ConstantVariable( + math.ceil(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), +) + +Dispatcher.register( + math.floor, + ("ConstantVariable",), + lambda var: ConstantVariable( + math.floor(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), +) + +Dispatcher.register( + ord, + ("ConstantVariable",), + lambda var: var.ord(), +) + +Dispatcher.register( + chr, + ("ConstantVariable",), + lambda var: var.chr(), +) + + +# pow +# base ** exp % mod +@Dispatcher.register_decorator(pow) +def dispatch_pow(base: VariableBase, exp: VariableBase, mod: VariableBase = None): # type: ignore + graph = base.graph + result = BuiltinVariable(operator.pow, graph, DanglingTracker())(base, exp) + if exp is not None: + result = BuiltinVariable(operator.mod, graph, DanglingTracker())( + result, mod + ) + return result + + +Dispatcher.register( + math.pow, + ("ConstantVariable", "ConstantVariable"), + lambda var1, var2: ConstantVariable( + math.pow(var1.get_py_value(), var2.get_py_value()), + var1.graph, + tracker=DummyTracker([var1, var2]), + ), +) + + +@Dispatcher.register_decorator(sum) +def dispatch_sum(var: ContainerVariable | TensorVariable, start: VariableBase = None): # type: ignore + if start is None: + start = ConstantVariable.wrap_literal(0, var.graph) + elements = [ + var.getitem(ConstantVariable.wrap_literal(i, var.graph)) + for i in range(len(var)) + ] + result = reduce( + BuiltinVariable(operator.add, var.graph, DanglingTracker()), + elements, + start, + ) + return result + + +Dispatcher.register( + max, + ("ListVariable",), + lambda var: var.max(), +) + +Dispatcher.register( + min, + ("ListVariable",), + lambda var: var.min(), +) + +Dispatcher.register( + math.sqrt, + ("ConstantVariable",), + lambda var: ConstantVariable( + math.sqrt(var.get_py_value()), + var.graph, + tracker=DummyTracker([var]), + ), +) + + +def constant_numpy_equal(left, right): + numpy_ans = left.get_py_value() == right.get_py_value() + return NumpyVariable( + numpy_ans, + left.graph, + tracker=DummyTracker([left, right]), + ) + + +Dispatcher.register( + operator.eq, + ("ConstantVariable", "NumpyVariable"), + lambda left, right: constant_numpy_equal(left, right), +) + +Dispatcher.register( + bool, + ("NumpyVariable",), + lambda x: ConstantVariable( + bool(x.get_py_value()), + x.graph, + tracker=DummyTracker([x]), + ), +) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py new file mode 100644 index 00000000000000..e7389de5b88050 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variable_stack.py @@ -0,0 +1,216 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload + +if TYPE_CHECKING: + ValidateValueFunc = Callable[[Any], None] + + +StackDataT = TypeVar("StackDataT") + + +class VariableStack(Generic[StackDataT]): + """ + A stack class for storing variables. + + Examples: + >>> var1, var2, var3, var4 = range(1, 5) + >>> stack = VariableStack() + >>> stack.push(var1) + >>> stack.push(var3) + >>> stack.insert(1, var2) + >>> stack + [1, 2, 3] + >>> stack.pop() + 3 + >>> stack.pop_n(2) + [1, 2] + >>> stack.push(var1) + >>> stack.push(var2) + >>> stack.push(var3) + >>> stack + [1, 2, 3] + >>> stack.top + 3 + >>> stack.peek[1] + 3 + >>> stack.peek[:1] + [3] + >>> stack.peek[:2] + [2, 3] + >>> stack.peek[1] = var4 + >>> stack + [1, 2, 4] + + """ + + class VariablePeeker: + @overload + def __getitem__(self, index: int) -> StackDataT: + ... + + @overload + def __getitem__(self, index: slice) -> list[StackDataT]: + ... + + @overload + def __call__(self, index: int = 1) -> StackDataT: + ... + + @overload + def __call__(self, index: slice) -> list[StackDataT]: + ... + + def __init__( + self, data: list[StackDataT], validate_value_func: ValidateValueFunc + ): + self._data = data + self.validate_value_func = validate_value_func + + def __getitem__( + self, index: int | slice + ) -> StackDataT | list[StackDataT]: + if isinstance(index, int): + assert 0 < index <= len(self._data) + return self._data[-index] + if isinstance(index, slice): + assert ( + index.start is None and index.step is None + ), "slice which has start or step not supported" + assert 0 < index.stop <= len(self._data) + return self._data[-index.stop :] + raise NotImplementedError(f"index type {type(index)} not supported") + + def __setitem__(self, index: int, value: Any): + assert isinstance( + index, int + ), f"index type {type(index)} not supported" + assert ( + 0 < index <= len(self._data) + ), f"index should be in [1, {len(self._data)}], but get {index}" + self.validate_value_func(value) + self._data[-index] = value + + def __call__( + self, index: int | slice = 1 + ) -> StackDataT | list[StackDataT]: + return self[index] + + def __init__( + self, + data: list[StackDataT] | None = None, + *, + validate_value_func: ValidateValueFunc | None = None, + ): + if data is None: + data = [] + else: + data = data.copy() + self.validate_value_func = ( + (lambda _: None) + if validate_value_func is None + else validate_value_func + ) + self._data = data + self._peeker = VariableStack.VariablePeeker( + self._data, self.validate_value_func + ) + + def copy(self): + return VariableStack( + self._data, validate_value_func=self.validate_value_func + ) + + def push(self, val: StackDataT): + """ + Pushes a variable onto the stack. + + Args: + val: The variable to be pushed. + + """ + self.validate_value_func(val) + self._data.append(val) + + def insert(self, index: int, val: StackDataT): + """ + Inserts a variable onto the stack. + + Args: + index: The index at which the variable is to be inserted, the top of the stack is at index 0. + val: The variable to be inserted. + + """ + assert ( + 0 <= index <= len(self) + ), f"index should be in [0, {len(self)}], but get {index}" + self.validate_value_func(val) + self._data.insert(len(self) - index, val) + + def pop(self) -> StackDataT: + """ + Pops the top value from the stack. + + Returns: + The popped value. + + """ + assert len(self) > 0, "stack is empty" + return self._data.pop() + + def pop_n(self, n: int) -> list[StackDataT]: + """ + Pops the top n values from the stack. + + Args: + n: The number of values to pop. + + Returns: + A list of the popped values. + + """ + assert ( + len(self) >= n >= 0 + ), f"n should be in [0, {len(self)}], but get {n}" + if n == 0: + return [] + retval = self._data[-n:] + self._data[-n:] = [] + return retval + + @property + def peek(self) -> VariablePeeker: + return self._peeker + + @property + def top(self) -> StackDataT: + assert len(self) > 0, "stack is empty" + return self.peek[1] + + @top.setter + def top(self, value): + assert len(self) > 0, "stack is empty" + self.peek[1] = value + + def __iter__(self): + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + return str(self._data) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py new file mode 100644 index 00000000000000..9611734ffffcdd --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import ( # noqa: F401 + ConstTypes, + VariableBase, + VariableFactory, + find_traceable_vars, + map_variables, +) +from .basic import ( # noqa: F401 + CellVariable, + ConstantVariable, + DataVariable, + DygraphTracerVariable, + FunctionGlobalVariable, + GlobalVariable, + ModuleVariable, + NullVariable, + NumpyVariable, + ObjectVariable, + SliceVariable, + TensorVariable, +) +from .callable import ( # noqa: F401 + BuiltinVariable, + CallableVariable, + ClassVariable, + ContainerLayerVariable, + FunctionVariable, + LayerVariable, + MethodVariable, + PaddleApiVariable, + PaddleLayerVariable, + UserDefinedFunctionVariable, + UserDefinedGeneratorVariable, + UserDefinedLayerVariable, +) +from .container import ( # noqa: F401 + ContainerVariable, + DictVariable, + ListVariable, + RangeVariable, + TupleVariable, +) +from .iter import ( # noqa: F401 + EnumerateVariable, + IterVariable, + MapVariable, + SequenceIterVariable, + UserDefinedIterVariable, +) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/base.py b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py new file mode 100644 index 00000000000000..17cb99aeef516a --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py @@ -0,0 +1,618 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import operator +from functools import cached_property +from queue import Queue +from typing import TYPE_CHECKING, Any, Callable, Optional + +import paddle + +from ....profiler import event_register +from ....utils import NameGenerator, get_unbound_method, log +from ....utils.exceptions import FallbackError, HasNoAttributeError +from ..dispatcher import Dispatcher +from ..guard import StringifyExpression, check_guard, union_free_vars +from ..mutable_data import MutableDictLikeData +from ..pycode_generator import PyCodeGen +from ..tracker import ( + DummyTracker, + GetAttrTracker, + GetItemTracker, + GetIterTracker, + Tracker, +) + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + # Each variable object should implement a method called `from_value`, + # which should adhere to the FromValueFunc signature. + FromValueFunc = Callable[ + [Any, FunctionGraph, Tracker], Optional["VariableBase"] + ] + + +ConstTypes = (int, float, str, bool, type(None)) + + +@event_register("find_traceable_vars") +def find_traceable_vars( + root_vars: list[VariableBase], +) -> list[VariableBase]: + """ + This function is used to find all traceable variables in the given list of variables. + + Args: + root_vars (list[VariableBase]): A list of root variables from which the ordering starts. + + Returns: + list[VariableBase]: A list of variables that are traceable. + """ + results: list[VariableBase] = [] + visited: set[VariableBase] = set() + queue: Queue[VariableBase] = Queue() + + for root in root_vars: + queue.put(root) + + while not queue.empty(): + var = queue.get() + if var in visited: + continue + + visited.add(var) + if var.tracker.need_guard(): + results.append(var) + continue + + # Pruning traceable variable, if the variable is traceable, we don't need to + # trace its inputs. + inputs = var.get_inputs() + + for var in inputs: + if var not in visited and var not in queue.queue: + queue.put(var) + + return results + + +def map_variables(map_func, variables: list[VariableBase]): + """ + This function maps the given map_func to the given list of variables in a recursive manner. + Args: + map_func (Callable[[VariableBase], Any]): The function to be mapped to each variable. + variables (list[VariableBase]): A list of variables to which the map_func is to be applied. + + Returns: + tuple: The result of applying the map_func to the variables. + """ + + def _map_variable(variable: VariableBase | object): + from .basic import SliceVariable + from .container import ContainerVariable + + if isinstance(variable, ContainerVariable): + return paddle.utils.map_structure( + _map_variable, variable.get_wrapped_items() + ) + + if isinstance(variable, SliceVariable): + return slice( + map_func(variable.getattr("start")), + map_func(variable.getattr("stop")), + map_func(variable.getattr("step")), + ) + + return map_func(variable) + + return paddle.utils.map_structure(_map_variable, variables) + + +class VariableFactory: + """ + A factory class for creating variables from arbitrary values. + + This class provides a set of registration and factory methods for creating variables + of different types based on the type of the input value. + + """ + + registered_funcs: dict[str, list[str]] = {"default": []} + mapping_str_func: dict[str, FromValueFunc] = {} + + @staticmethod + def default_from_value(value, graph, tracker): + """ + A default factory function that creates an ObjectVariable from the given value. + + Args: + value: The input value. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + + Returns: + ObjectVariable: A new ObjectVariable representing the input value. + """ + from .basic import ObjectVariable + + return ObjectVariable(value, graph, tracker) + + @staticmethod + def register_from_value(*, successor: str | None = None): + """ + A decorator function that registers a function for creating a Variable from a value. + + Args: + successor (str | None, optional): The name of the successor function that will be called after this function when creating a Variable. If None, the function is added to a default list of functions. + + Returns: + The _register_from_value decorator function, which takes the function to be registered as an argument. + """ + registered_funcs = VariableFactory.registered_funcs + mapping_str_func = VariableFactory.mapping_str_func + + def _register_from_value(func: FromValueFunc): + """ + Function to register a function for creating a Variable from a value + """ + # Get the name of the function + name = func.__qualname__.split(".")[0] + # Map the name of the function to the function + mapping_str_func[name] = func + if successor is None: + registered_funcs["default"].append( + name + ) # If successor is None, add the function to the "default" list + elif successor not in registered_funcs.keys(): + registered_funcs[successor] = [ + name + ] # If the successor is not in the registered_funcs dictionary, set the value to a list containing only name + else: + registered_funcs[successor].append( + name + ) # If the successor is in the registered_funcs dictionary, append name to the existing list of functions for that successor + + log( + 4, VariableFactory.registered_funcs + ) # Print the registered_funcs dictionary if the logging level is at least 4 + return _register_from_value + + @staticmethod + def from_value( + value: Any, + graph: FunctionGraph, + tracker: Tracker, + *, + debug_name: str | None = None, + ) -> VariableBase: + """ + Create a new variable object from the given value. + + This method searches through the registered from_value functions to find one + that can create a variable object from the given value. If no matching function + is found, the default_from_value function is used. + + Args: + value (Any): The input value. + graph (FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker (Tracker): The Tracker object that tracks the information of this variable. + debug_name (str | None): An optional debug name for the variable. + + Returns: + VariableBase: A new variable object representing the input value. + """ + registered_funcs = VariableFactory.registered_funcs + + def _find_var(key: str = "default") -> VariableBase | None: + for name in registered_funcs[key]: + if name in registered_funcs.keys(): + # If the function name is a key in the registered_funcs dictionary, recursively find a Variable using that function + var = _find_var(name) + if var is not None: + return var + # Get the function corresponding to the name from the mapping_str_func dictionary + func = VariableFactory.mapping_str_func[name] + var = func( + value, graph, tracker + ) # Call the function to create a Variable from the value + if var is not None: + return var + + var = _find_var() + if var is None: + var = VariableFactory.default_from_value( + value, graph, tracker + ) # If a Variable could not be found using the registered functions, use the default function to create a new Variable + var.debug_name = debug_name + return var + + +class VariableBase: + """ + VariableBase is a basic concept and each symbols in VM stack is regarded as + an Variable Object in symblic tracing process. + + There are two key data structures during Python runtime: + PyFrameObject, which provides the instance for function logical lock usage, + and PyCodeObject, which provides the bytecode for the corresponding function. + With these data, the Python virtual machine executes the bytecode sequentially on a stack to complete function logic. + + Args: + tracker(Tracker): The Tracker object that tracks the information of this variable. + + Note: + We should push an object of a subclass of VariableBase instead of an object of VariableBase onto the VM stack. + It serves as an abstract class and should not be instantiated directly. + """ + + tracker: Tracker # An attribute to store the Tracker object associated with the variable + value: Any + name_generator = NameGenerator( + "object_" + ) # A class-level attribute to generate names for new variables + mutable_attrs = [] + + def __init__(self, graph: FunctionGraph, tracker: Tracker): + self.graph = graph + self.tracker = tracker + self.id = VariableBase.name_generator.next() + self._debug_name: str | None = None + + @property + def main_info(self) -> dict[str, Any]: + """ + Property method to return a dictionary of main information about the variable + + Returns: + main_info: Main information of the variable. + """ + return {} + + @property + def debug_info(self) -> dict[str, Any]: + """ + Property method to return a dictionary of debug information about the variable + """ + return { + "debug_name": self.debug_name, + "id": self.id, + } + + @property + def debug_name(self) -> str: + """ + Generate a debug_name for each variable. + + Returns: + _debug_name: the name of variable. + """ + if self._debug_name is not None: + # Return the self._debug_name cache if it is not None. + return self._debug_name + inputs = self.tracker.inputs + if isinstance(self.tracker, GetItemTracker): + self._debug_name = ( + f"{self.tracker.container.debug_name}[{self.tracker.key}]" + ) + elif isinstance(self.tracker, GetAttrTracker): + self._debug_name = ( + f"{self.tracker.obj.debug_name}.{self.tracker.attr}" + ) + elif len(inputs) == 0: + self._debug_name = "tmp_var" + else: # len(inputs) >= 0 + for input in inputs: + assert input is not None + self._debug_name = "tmp_var_" + "_".join( + input.debug_name for input in inputs + ) + return self._debug_name + + @debug_name.setter + def debug_name(self, name): + self._debug_name = name + + def __hash__(self): + return hash(self.id) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + """ + Create a StringifyExpression object that represents a guard expression for this variable. + + Returns: + StringifyExpression: An object that contains the guard expression and the free variables used in the expression. + """ + + # Get a ValueTracer object from the Tracker object associated with the variable + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + f"id(type({{}})) == {id(self.get_py_type())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + StringifyExpression( + f"{{}} == {self.get_py_value()!r}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + ] + + def get_py_value(self, allow_tensor=False) -> Any: + """ + Abstract method to get the value of the variable + """ + raise NotImplementedError() + + def get_py_type(self): + """ + Method to get the type of the variable's value + """ + return type(self.get_py_value()) + + def is_none(self) -> bool: + """ + Method to check if the variable's value is None + """ + return self.get_py_value() is None + + def reconstruct( + self, + codegen: PyCodeGen, + *, + use_tracker: bool = True, + add_to_global_guarded_vars: bool = True, + ): + if self.tracker.is_traceable() and use_tracker: + self.tracker.gen_instructions(codegen) + else: + if add_to_global_guarded_vars: + self.graph.add_global_guarded_variable(self) + self._reconstruct(codegen) + + def _reconstruct(self, codegen: PyCodeGen): + """ + Abstract method to construct an opcode and append it into codegen.instructions + """ + raise FallbackError( + f'{self.__class__.__name__} does not implement "_reconstruct" method' + ) + + def flatten_items(self) -> list[VariableBase]: + """ + Recursively flatten the items in this container variable to a list of Variable objects. + + Returns: + list[VariableBase]: Flattened items of a container variable. + """ + from .container import ContainerVariable + + if not isinstance(self, ContainerVariable): + return [self] + flattened_items = [] + for item in self.get_items(): + flattened_items.extend(item.flatten_items()) + return flattened_items + + def get_inputs(self) -> list[VariableBase]: + """ + This method is used to get the inputs for the current variable. + + Returns: + list[VariableBase]: Inputs for the current variable. + """ + return self.tracker.inputs + + def get_traceable_inputs(self) -> list[VariableBase]: + """ + This method is used to get the traceable inputs for the current variable. + + Returns: + list[VariableBase]: Traceable inputs for the current variable. + """ + return list( + filter(lambda x: x.tracker.is_traceable(), self.tracker.inputs) + ) + + def call_function(self, /, *args, **kwargs): + pass + + @cached_property + def attr_proxy(self): + return self.graph.side_effects.get_proxy( + MutableDictLikeData, self.get_py_value(), self.attr_proxy_getter + ) + + def attr_proxy_getter(self, proxy: MutableDictLikeData, name: str): + if not hasattr(proxy.original_data, name): # can't true. + return MutableDictLikeData.Empty() + + attr = getattr(proxy.original_data, name) + if inspect.ismethod(attr) or ( + hasattr(attr, "__self__") + and inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ) + ): + from .callable import MethodVariable + + fn = None + if inspect.ismethoddescriptor( + getattr(attr.__self__.__class__, name, None) + ): + class_var = VariableFactory.from_value( + self.get_py_type(), + self.graph, + GetAttrTracker(self, "__class__"), + ) + fn = VariableFactory.from_value( + getattr(attr.__self__.__class__, name), + self.graph, + GetAttrTracker(class_var, name), + ) + return MethodVariable.wrap_method( + value=attr, + instance=self, + fn=fn, + graph=self.graph, + tracker=GetAttrTracker(self, name), + method_name=name, + ) + + return VariableFactory.from_value( + attr, self.graph, tracker=GetAttrTracker(self, name) + ) + + def hasattr(self, name: str): + from .basic import ConstantVariable + + try: + self.getattr(name) + return ConstantVariable( + True, graph=self.graph, tracker=DummyTracker([self]) + ) + except HasNoAttributeError: + # NOTE(SigureMo): Only the HasNoAttributeError is raised, we can + # ensure that the attribute does not exist. Otherwise, we should + # raise the error. + return ConstantVariable( + False, graph=self.graph, tracker=DummyTracker([self]) + ) + + def getattr(self, name: str, default=None): + result = self.attr_proxy.get(name) + if isinstance(result, MutableDictLikeData.Empty): + if default is not None: + assert isinstance(default, VariableBase) + return default + raise HasNoAttributeError( + f"{self.__class__.__name__} {self} has no attribute {name}" + ) + return result + + def setattr(self, key: str, value): + from .basic import ConstantVariable + + self.attr_proxy.set(key, value) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def delattr(self, key: str): + from .basic import ConstantVariable + + self.attr_proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def __setitem__(self, key, value): + return self.setitem(key, value) + + def setitem(self, key, value): + raise FallbackError(f"{self} is not support setitem.") + + def __repr__(self): + info = {**self.main_info, **self.debug_info} + info_str = ", ".join([f"{value}" for value in info.values()]) + return f"{self.__class__.__name__}({info_str})" + + def __str__(self): + return self.__repr__() + + def __getitem__(self, idx): + return Dispatcher.call(operator.getitem, self, idx) + + def getitem(self, item): + class_var = VariableFactory.from_value( + self.get_py_value().__class__, + self.graph, + GetAttrTracker(self, '__class__'), + ) + fn_var = VariableFactory.from_value( + get_unbound_method(self.get_py_value(), '__getitem__'), + self.graph, + GetAttrTracker(class_var, '__getitem__'), + ) + self.graph.add_global_guarded_variable(item) + item = item.get_py_value() + output = fn_var(self, item) + return output + + def __call__(self, /, *args, **kwargs): + """ + Call the object represented by this variable with the given arguments. + + Args: + *args: Positional arguments to pass to the object's __call__ method. + **kwargs: Keyword arguments to pass to the object's __call__ method. + + Returns: + VariableBase: A new variable representing the result of calling the object's __call__ method. + """ + from .callable import BuiltinVariable, UserDefinedFunctionVariable + + class_var = VariableFactory.from_value( + self.get_py_value().__class__, + self.graph, + GetAttrTracker(self, '__class__'), + ) + assert class_var is not None + # if __call__ is a method, we should add self to arguments. + if inspect.ismethod(self.get_py_value().__call__): + args = (self,) + args + unbound_method = get_unbound_method(self.get_py_value(), '__call__') + if hasattr(unbound_method, "__code__"): + fn_var = UserDefinedFunctionVariable( + unbound_method, + self.graph, + GetAttrTracker(class_var, '__call__'), + ) + else: + fn_var = BuiltinVariable( + self.value, + self.graph, + GetAttrTracker(class_var, '__call__'), + ) + output = fn_var(*args, **kwargs) + return output + + def get_iter(self): + from .iter import UserDefinedIterVariable + + return UserDefinedIterVariable(self, self.graph, GetIterTracker(self)) + + @VariableFactory.register_from_value() + def from_value( + value: Any, + graph: FunctionGraph | None, + tracker: Tracker, + ) -> VariableBase | None: + """ + Create a new variable from a given value, or return None if the value cannot be converted to a variable. + Args: + value (Any): The value to create a variable from. + graph (FunctionGraph | None): The graph in which the variable will be used. + tracker (Tracker): The variable tracker to put the new variable in if created. + + Returns: + VariableBase | None: A new variable if one can be created from the given value, or None if the value cannot be converted to a variable. + """ + if isinstance(value, VariableBase): + return value + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py new file mode 100644 index 00000000000000..ba0a7f51c91a03 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/basic.py @@ -0,0 +1,888 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import operator +import types +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Any + +import numpy as np + +import paddle + +from ....infer_meta import MetaInfo +from ....symbolic.statement_ir import Symbol +from ....utils import ( + BreakGraphError, + FallbackError, + NameGenerator, + paddle_tensor_methods, +) +from ....utils.exceptions import HasNoAttributeError, InnerError +from ..dispatch_functions import tensor_numel +from ..guard import ( + StringifyExpression, + check_guard, + object_equal_stringify_guard, + union_free_vars, +) +from ..mutable_data import MutableDictLikeData +from ..pycode_generator import PyCodeGen +from ..tracker import ( + ConstTracker, + DanglingTracker, + DummyTracker, + GetAttrTracker, + GetIterTracker, + GlobalTracker, + Tracker, +) +from .base import ConstTypes, VariableBase, VariableFactory + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + from .callable import FunctionVariable + + +FP_DTYPE_ABBRS = { + paddle.bfloat16: 'bfloat16', + paddle.float64: 'float64', + paddle.float32: 'float32', + paddle.float16: 'float16', +} + +CP_DTYPE_ABBRS = { + paddle.complex64: 'complex64', + paddle.complex128: 'complex128', +} + +INT_DTYPE_ABBRS = { + paddle.int8: 'int8', + paddle.int16: 'int16', + paddle.int32: 'int32', + paddle.int64: 'int64', + paddle.uint8: 'uint8', +} + +DTYPE_ABBRS = { + **FP_DTYPE_ABBRS, + **CP_DTYPE_ABBRS, + **INT_DTYPE_ABBRS, + paddle.bool: 'bool', +} + + +class ConstantVariable(VariableBase): + """ + ConstantVariable is a subclass of VariableBase used to wrap a Variable of the const type. + + Args: + value(Any): The value to be wrapped. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + value: Any, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + @property + def debug_name(self) -> str: + return f"{self.value}" + + @debug_name.setter + def debug_name(self, name): + pass + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_const(self.value) + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def __bool__(self) -> bool: + return bool(self.value) + + def bool(self): + return ConstantVariable(bool(self), self.graph, DummyTracker([self])) + + def bool_not(self): + assert isinstance( + self.get_py_value(), bool + ), "Bool_not can only be applied to a bool variable." + return ConstantVariable( + not bool(self.get_py_value()), self.graph, DummyTracker([self]) + ) + + def str(self): + return ConstantVariable( + str(self.value), self.graph, DummyTracker([self]) + ) + + def format(self, *args): + return ConstantVariable( + str(self.value).format(*[str(a.value) for a in args]), + self.graph, + DummyTracker([self, *args]), + ) + + def lower(self): + return ConstantVariable( + str(self.value).lower(), + self.graph, + DummyTracker([self]), + ) + + def ord(self): + return ConstantVariable( + ord(self.value), + self.graph, + DummyTracker([self]), + ) + + def chr(self): + return ConstantVariable( + chr(self.value), + self.graph, + DummyTracker([self]), + ) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) in ConstTypes: + return ConstantVariable(value, graph, tracker) + return None + + @staticmethod + def wrap_literal(value: Any, graph: FunctionGraph) -> ConstantVariable: + """ + Wrap a literal value in a ConstantVariable. + + Args: + value(Any): The literal value to be wrapped. + + Returns: + ConstantVariable: A new ConstantVariable object that wraps the given value. + """ + if isinstance(value, ConstantVariable): + return value + assert isinstance( + value, ConstTypes + ), f"value: {value},type: {type(value)}" + return ConstantVariable(value, graph, ConstTracker(value)) + + +class PrintStmtVariable(VariableBase): + def __init__(self, value: Any, graph: FunctionGraph): + # TODO: graph should be not None + super().__init__(None, DanglingTracker()) + self.args, self.kwargs = value + self.graph = graph + + def _reconstruct(self, codegen: PyCodeGen): + # do we need ? may be too strict. + for var in self.args: + self.graph.add_global_guarded_variable(var) + for var in self.kwargs.values(): + self.graph.add_global_guarded_variable(var) + # currently dont' consider kwargs + codegen.gen_load_global("print", push_null=True) + for var in self.args: + var.reconstruct(codegen) + codegen.gen_call_function(len(self.args)) + codegen.gen_pop_top() + + def flatten_items(self): + return self.args + + +IMPLEMENTED_TENSOR_PROPERTIES = set() + + +def tensor_property(func): + IMPLEMENTED_TENSOR_PROPERTIES.add(func.__name__) + return property(func) + + +class DataVariable(VariableBase): + """ + A value only object. + If it's all magic method don't change the function_graph state, [tensor op, guard, side_effect] + we will call it a ValueObjectVariable, we directy call python operator on it. + """ + + def __init__( + self, + value: Any, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + +class TensorDtypeVariable(DataVariable): + def __init__(self, value, graph, tracker): + super().__init__(value, graph, tracker) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.tracker, GetAttrTracker) and isinstance( + self.tracker.obj, TensorVariable + ): + tensor_value_tracer = ( + self.tracker.obj.tracker.trace_value_from_frame() + ) + return [ + StringifyExpression( + f"str(MetaInfo.from_tensor({{}}).dtype) == '{str(self.value)}'", + [tensor_value_tracer], + {"MetaInfo": MetaInfo}, + ) + ] + else: + return object_equal_stringify_guard(self) + + @property + def main_info(self) -> dict[str, Any]: + return { + "dtype": self.value, + } + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.dtype): + return TensorDtypeVariable(value, graph, tracker) + + +class TensorVariable(VariableBase): + """ + TensorVariable is a subclass of VariableBase used to wrap a Variable of the tensor type. + + Args: + tensor (paddle.Tensor | MetaInfo): The tensor to be wrapped. + graph (FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker (Tracker): The Tracker object that tracks the information of this variable. + """ + + var_name_generator = NameGenerator("var_") + mutable_attrs = ["meta"] + + def __init__( + self, + tensor: paddle.Tensor | MetaInfo, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + if isinstance(tensor, paddle.Tensor): + self.value = None + self.meta = MetaInfo.from_tensor(tensor) + elif isinstance(tensor, MetaInfo): + self.value = None + self.meta = tensor + else: + raise InnerError( + "Required type(tensor) is paddle.Tensor or ProxyTensor, but received {}.".format( + type(tensor).__name__ + ) + ) + self.origin_meta = self.meta + self.var_name = TensorVariable.var_name_generator.next() + self.graph.side_effects.record_mutable_variable(self) + + def __len__(self): + if self.meta.shape[0] == -1: + raise BreakGraphError( + "length of tensor variable with first dimension == -1" + ) + return self.meta.shape[0] + + def get_py_value(self, allow_tensor=False): + if allow_tensor: + + class SotTensor: + def __init__(self, id_): + self.id = id_ + + def __eq__(self, var): + if not hasattr(var, "id"): + return False + else: + return self.id == var.id + + return SotTensor(self.id) + + raise BreakGraphError( + "Called TensorVariable.get_py_value. Should not use Tensor's value in simulating." + ) + + def get_py_type(self): + return paddle.Tensor + + def get_symbol(self) -> Symbol: + return Symbol(self.var_name) + + @property + def out_var_name(self): + return f"{self.graph.OUT_VAR_PREFIX}{self.var_name}" + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_fast(self.out_var_name) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + f"MetaInfo.from_tensor({{}}).guard_str() == '{self.origin_meta.guard_str()}'", + [frame_value_tracer], + union_free_vars( + {"MetaInfo": MetaInfo}, + frame_value_tracer.free_vars, + ), + ) + ] + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + @property + def main_info(self) -> dict[str, Any]: + return { + "shape": self.meta.shape, + "dtype": DTYPE_ABBRS[self.meta.dtype], + "stop_gradient": self.meta.stop_gradient, + "var_name": self.var_name, + } + + def getitem(self, key): + return self.graph.call_tensor_method('__getitem__', self, key) + + def setitem(self, key, value): + self.graph.add_global_guarded_variable(value) + + key_var = VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) + ) + new_tensor = self.graph.call_paddle_api( + paddle.static.setitem, + self, + key_var, + value, + ) + + self.meta = new_tensor.meta + self.graph.add_inplace_tensors(self) + + @tensor_property + def T(self): + """ + Return a new TensorVariable object that wraps the result of calling the transpose method on the wrapped value of this TensorVariable. + """ + from .container import ListVariable + + perm = list(range(len(self.meta.shape) - 1, -1, -1)) + perm_var = ListVariable(perm, self.graph, tracker=ConstTracker(perm)) + assert perm_var is not None + out = self.graph.call_paddle_api(paddle.transpose, self, perm_var) + return out + + @tensor_property + def ndim(self): + """ + Return a ConstantVariable object that represents the number of dimensions of the wrapped value of this TensorVariable. + """ + return ConstantVariable( + len(self.meta.shape), self.graph, DummyTracker([self]) + ) + + @tensor_property + def size(self): + """ + Return a ConstantVariable object that represents the total number of elements in the wrapped value of this TensorVariable. + """ + # TODO: maybe break graph. + if self.meta.is_dynamic_shape(): + raise BreakGraphError( + f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}" + ) + elements = reduce(operator.mul, self.meta.shape, 1) + return ConstantVariable(elements, self.graph, DummyTracker([self])) + + @tensor_property + def shape(self): + if self.meta.is_dynamic_shape(): + raise BreakGraphError( + f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}" + ) + from .container import ListVariable + + return ListVariable( + self.meta.shape, self.graph, tracker=DummyTracker([self]) + ) + + def numel(self): + return self.size + + def len(self): + if len(self.meta.shape) == 0: + raise InnerError("len() of a 0-D tensor is wrong") + first_dim = self.meta.shape[0] + if first_dim == -1: + raise BreakGraphError( + "Getting len() for a dynamic shape tensor causes graph break." + ) + + return ConstantVariable(first_dim, self.graph, DummyTracker([self])) + + def is_tensor(self): + return ConstantVariable(True, self.graph, DummyTracker([self])) + + def is_complex(self): + dtype = self.meta.dtype + is_cp_dtype = dtype in CP_DTYPE_ABBRS + return ConstantVariable(is_cp_dtype, self.graph, DummyTracker([self])) + + def is_integer(self): + dtype = self.meta.dtype + is_int_dtype = dtype in INT_DTYPE_ABBRS + return ConstantVariable(is_int_dtype, self.graph, DummyTracker([self])) + + def is_floating_point(self): + dtype = self.meta.dtype + is_fp_dtype = dtype in FP_DTYPE_ABBRS + return ConstantVariable(is_fp_dtype, self.graph, DummyTracker([self])) + + def getattr(self, name: str, default=None): + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + method_name_to_builtin_fn = { + "dim": paddle.rank, + "numel": tensor_numel, + "ndimension": paddle.rank, + "is_tensor": paddle.is_tensor, + "is_complex": paddle.is_complex, + "is_integer": paddle.is_integer, + "is_floating_point": paddle.is_floating_point, + } + if name in ["dtype", "type", "name", "persistable", "stop_gradient"]: + if name == "name" and self.meta.name.startswith( + "infer_meta_variable_tmp" + ): + raise BreakGraphError(f"{self.meta.name} is a middle tensor.") + return VariableFactory.from_value( + getattr(self.meta, name), + self.graph, + tracker=GetAttrTracker(self, name), + ) + elif name in IMPLEMENTED_TENSOR_PROPERTIES: + return getattr(self, name) + elif name in method_name_to_builtin_fn: + # TODO: backward, gradient + from .callable import BuiltinVariable + + builtin_fn = method_name_to_builtin_fn[name] + + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + elif name in paddle_tensor_methods: + from .callable import TensorFunctionVariable + + fn_var = TensorFunctionVariable( + name, graph=self.graph, tracker=DanglingTracker() + ) + return fn_var.bind(self, name) + else: + raise HasNoAttributeError(f"Unknown Tensor attribute: {name}") + + def setattr(self, key, val): + # support tensor variable store attr, like: + # t.stop_gradient = True + self.graph.call_tensor_method( + "__setattr__", + self, + VariableFactory().from_value(key, self.graph, ConstTracker(key)), + val, + ) + + def delattr(self, key): + raise BreakGraphError("Don't support TensorVariable delattr") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (paddle.Tensor, MetaInfo)): + return TensorVariable(value, graph, tracker) + return None + + +class ObjectVariable(VariableBase): + """ + ObjectVariable is a subclass of VariableBase used to wrap a Variable of the object type. + + Args: + obj(Any): The object to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + make_stringify_guard = object_equal_stringify_guard + + def __init__(self, obj, graph, tracker): + super().__init__(graph, tracker) + self.value = obj + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False) -> Any: + return self.value + + +class SliceVariable(VariableBase): + """ + SliceVariable is a subclass of VariableBase used to wrap a Variable of the slice type. + + Args: + slice_(slice): The slice to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__(self, slice_: slice, graph, tracker): + super().__init__(graph, tracker) + self.value = slice_ + + @property + def debug_name(self) -> str: + return ":".join( + [ + str(self.value.start) if self.value.start is not None else "", + str(self.value.stop) if self.value.stop is not None else "", + str(self.value.step) if self.value.step is not None else "", + ] + ) + + @debug_name.setter + def debug_name(self, name): + pass + + @cached_property + def attr_proxy(self): + return self.graph.side_effects.get_proxy( + MutableDictLikeData, self.value, self.attr_proxy_getter + ) + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False): + return slice( + self.getattr("start").get_py_value(), + self.getattr("stop").get_py_value(), + self.getattr("step").get_py_value(), + ) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + result = ( + [ + StringifyExpression( + "isinstance({}, slice)", + [frame_value_tracer], + frame_value_tracer.free_vars, + ), + ] + + self.getattr("start").make_stringify_guard() + + self.getattr("stop").make_stringify_guard() + + self.getattr("step").make_stringify_guard() + ) + return result + + def _reconstruct(self, codegen: PyCodeGen): + if all( + isinstance(x, ConstantVariable) + for x in [ + self.getattr("start"), + self.getattr("stop"), + self.getattr("step"), + ] + ): + self.graph.add_global_guarded_variable(self) + self.getattr("start").reconstruct(codegen) + self.getattr("stop").reconstruct(codegen) + self.getattr("step").reconstruct(codegen) + codegen.gen_build_slice(3) + else: + super()._reconstruct(codegen) + + def setattr(self, key, val): + raise BreakGraphError("Don't support SliceVariable setattr") + + def delattr(self, key): + raise BreakGraphError("Don't support SliceVariable delattr") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, slice): + return SliceVariable(value, graph, tracker) + return None + + +class ModuleVariable(VariableBase): + """ + ModuleVariable is a subclass of VariableBase used to wrap a Variable of the module type. + + Args: + func: The module to be wrapped. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + """ + + def __init__(self, func, graph, tracker): + super().__init__(graph, tracker) + self.value = func + + def get_py_value(self, allow_tensor=False): + return self.value + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, types.ModuleType): + return ModuleVariable(value, graph, tracker) + return None + + # Happened in a inline import statement. + make_stringify_guard = object_equal_stringify_guard + + +class DygraphTracerVariable(VariableBase): + # TODO(SigureMo): Remove this trick after we add CompareTracker + def __init__(self, value, graph, tracker): + super().__init__(graph, tracker) + self.value = value + + def get_py_value(self, allow_tensor=False): + return self.value + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + return [] + + @property + def main_info(self) -> dict[str, Any]: + return { + "is_none": self.value is None, + } + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.base.dygraph.tracer.Tracer): + return DygraphTracerVariable(value, graph, tracker) + return None + + +class NumpyVariable(VariableBase): + """ + NumpyVariable is a subclass of VariableBase used to wrap a Variable of the numpy type. + + Args: + value: The numpy value to be wrapped. + graph: The FunctionGraph object that this variable is associated with. + tracker: The Tracker object that tracks the information of this variable. + """ + + def __init__(self, value, graph, tracker): + super().__init__(graph, tracker) + self.value = value + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + def get_py_value(self, allow_tensor=False) -> Any: + return self.value + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.get_py_value(), np.number): + frame_value_tracer = self.tracker.trace_value_from_frame() + + def format_dtype(dtype: np.dtype): + return f"np.{str(dtype)}" + + def format_number(number: np.number): + return f"{format_dtype(number.dtype)}({str(number.item())})" + + return [ + StringifyExpression( + f"{{}} == {format_number(self.get_py_value())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars, {"np": np}), + ), + StringifyExpression( + f"{{}}.dtype == {format_dtype(self.get_py_value().dtype)}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars, {"np": np}), + ), + ] + else: + return object_equal_stringify_guard(self) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (np.ndarray, np.number)): + return NumpyVariable(value, graph, tracker) + return None + + +class NullVariable(VariableBase): + """ + NullVariable is a subclass of VariableBase used to represent a placeholder variable that has no value or reference associated with it. + """ + + def __init__(self): + # TODO: graph should be not None + super().__init__(None, DanglingTracker()) + + def reconstruct(self, codegen: PyCodeGen): + codegen.gen_push_null() + + +class CellVariable(VariableBase): + def __init__(self, value=None): + # TODO: graph should be not None + super().__init__( + None, DanglingTracker() + ) # should reconstruct cell variable + assert isinstance(value, (VariableBase, type(None))) + self.set_value(value) + + def reconstruct( + self, + codegen: PyCodeGen, + *, + use_tracker: bool = True, + add_to_global_guarded_vars: bool = True, + ): + raise FallbackError("Break graph in closure is not support.") + + def cell_content(self): + return self.value + + def set_value(self, value): + self.value = value + + def empty(self): + return self.value is None + + +class GlobalVariable(VariableBase): + def __init__( + self, + val_dict, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.proxy = self.graph.side_effects.get_proxy( + MutableDictLikeData, val_dict, self.proxy_getter + ) + + def proxy_getter(self, proxy: MutableDictLikeData, key: Any): + if key not in proxy.original_data: + return MutableDictLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=GlobalTracker(key), + ) + + def get_value(self): + return dict(self.proxy.get_all().items()) + + def keys(self): + return self.proxy.get_all().keys() + + def get(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} to get value." + ) + return self.proxy.get(key) + + def set(self, key, value): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + if not isinstance(value, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {value} to set value." + ) + self.proxy.set(key, value) + self.graph.side_effects.record_proxy_variable(self) + + def delete(self, key): + self.proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + + +class FunctionGlobalVariable(GlobalVariable): + def __init__( + self, + fn: FunctionVariable, + val_dict: dict[str, Any], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(val_dict, graph, tracker) + self.fn = fn + + def proxy_getter(self, proxy: MutableDictLikeData, key: Any): + from ..opcode_inline_executor import FunctionGlobalTracker + + if key not in proxy.original_data: + return MutableDictLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=FunctionGlobalTracker(self.fn, key), + ) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py new file mode 100644 index 00000000000000..819580710beba8 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/callable.py @@ -0,0 +1,759 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import operator +import types +from functools import reduce +from typing import TYPE_CHECKING, Any, Callable + +import paddle + +from .... import psdb +from ....profiler import EventGuard +from ....utils import ( + is_break_graph_api, + is_break_graph_tensor_methods, + is_builtin_fn, + is_paddle_api, + magic_method_builtin_dispatch, +) +from ....utils.exceptions import BreakGraphError, FallbackError, SotErrorBase +from ..dispatcher import Dispatcher +from ..guard import ( + StringifyExpression, + check_guard, + object_equal_stringify_guard, + union_free_vars, +) +from ..tracker import ( + ConstTracker, + CreateLayerTracker, + DanglingTracker, + DummyTracker, + GetAttrTracker, + GetItemTracker, + GetIterTracker, + Tracker, +) +from .base import VariableBase, VariableFactory +from .basic import ConstantVariable, PrintStmtVariable, SliceVariable + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +PD_ALL_CONTAINERS = (paddle.nn.Sequential, paddle.nn.LayerList) +PD_SEQ_CONTAINERS = (paddle.nn.Sequential, paddle.nn.LayerList) + + +class CallableVariable(VariableBase): + """ + CallableVariable is a subclass of VariableBase used to wrap a callable variable. + + Args: + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__(self, graph: FunctionGraph, tracker: Tracker): + super().__init__(graph, tracker) + + def __call__(self, /, *args, **kwargs) -> VariableBase: + """Why we need '/' to make self positional only? + + If kwargs have {'self': xxx}, this function call raise a error. + See: test_str_format.py for details. + """ + with EventGuard(f"call_function: {self.__class__.__name__}"): + return self.call_function(*args, **kwargs) + + def call_function(self, /, *args, **kwargs): + raise NotImplementedError("call_function is not implemented.") + + +class FunctionVariable(CallableVariable): + """ + FunctionVariable is a subclass of CallableVariable used to wrap a function variable. + + Args: + fn (Callable[..., Any]): The function to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = fn + + def get_py_value(self, allow_tensor=False): + return self.value + + def get_code(self) -> types.CodeType: + return self.value.__code__ + + def bind(self, instance: VariableBase, name: str): + method_var = MethodVariable( + instance, + self, + graph=self.graph, + tracker=GetAttrTracker(instance, name), + ) + class_var = VariableFactory.from_value( + instance.get_py_type(), + graph=self.graph, + tracker=GetAttrTracker(instance, "__class__"), + ) + assert class_var is not None + self.tracker = GetAttrTracker(class_var, name) + return method_var + + make_stringify_guard = object_equal_stringify_guard + + +class UserDefinedFunctionVariable(FunctionVariable): + """ + UserDefinedFunctionVariable is a subclass of FunctionVariable used to wrap a user-defined function. + + Args: + fn (Callable[..., Any]): The user-defined function to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def handle_psdb_function(self, /, *args, **kwargs): + # special function for inner debug. + if self.value is psdb.assert_true: + return ConstantVariable.wrap_literal( + self.value(args[0].value), self.graph + ) + elif self.value is psdb.print: + sot_prefix = ConstantVariable.wrap_literal("[SOT]", self.graph) + self.graph.add_print_variables( + PrintStmtVariable(([sot_prefix, *args], kwargs), self.graph) + ) + return ConstantVariable.wrap_literal(None, self.graph) + elif self.value is psdb.breakpoint: + # do nothing. just return None. + from ...breakpoint import BM + + BM.locate(BM.executors[-1]) + BM.add(BM.cur_exe._code.co_filename, BM.cur_exe._current_line) + return ConstantVariable.wrap_literal(None, self.graph) + elif self.value is psdb.breakgraph: + raise BreakGraphError("breakgraph by psdb.breakgraph") + elif self.value is psdb.fallback: + raise FallbackError("fallback by psdb.fallback") + elif self.value is psdb.in_sot: + return ConstantVariable.wrap_literal(True, self.graph) + return None + + def call_function(self, /, *args, **kwargs) -> VariableBase: + from ..opcode_inline_executor import OpcodeInlineExecutor + + result = self.handle_psdb_function(*args, **kwargs) + if result is not None: + return result + + checkpoint = self.graph.save_memo() + try: + inline_executor = OpcodeInlineExecutor(self, *args, **kwargs) + with EventGuard( + f"Inline Call: {inline_executor._code.co_name.replace('<', '(').replace('>', ')')}, file {inline_executor._code.co_filename}, line {int(inline_executor._code.co_firstlineno)}" + ): + output = inline_executor.inline_call() + except SotErrorBase as e: + self.graph.restore_memo(checkpoint) + raise BreakGraphError( + f"({e}) raised while inline call {self.value.__code__}." + ) + return output + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, (types.FunctionType)): + return UserDefinedFunctionVariable(value, graph, tracker) + if isinstance( + value, paddle.jit.dy2static.program_translator.StaticFunction + ): + return UserDefinedFunctionVariable( + value.dygraph_function, graph, tracker + ) + return None + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__name__, + } + + +class PaddleApiVariable(FunctionVariable): + """ + PaddleApiVariable is a subclass of FunctionVariable used to wrap a paddlepaddle API function. + + Args: + fn (Callable[..., Any]): The paddlepaddle API to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, /, *args, **kwargs): + if is_break_graph_api(self.value): + raise BreakGraphError( + f"breakgraph by unsupport function: {self.value.__name__}" + ) + return self.graph.call_paddle_api(self.value, *args, **kwargs) + + @VariableFactory.register_from_value( + successor="UserDefinedFunctionVariable" + ) + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if callable(value) and is_paddle_api(value): + return PaddleApiVariable(value, graph, tracker) + return None + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__name__, + } + + make_stringify_guard = object_equal_stringify_guard + + +class TensorFunctionVariable(FunctionVariable): + """ + TensorFunctionVariable is a subclass of FunctionVariable used to wrap a method of a tensor. + + Args: + method_name (str): The name of the tensor method to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, method_name: str, graph: FunctionGraph, tracker: Tracker + ): + fn = getattr(paddle.static.Variable, method_name) + super().__init__(fn, graph, tracker) + self.method_name = method_name + + def call_function(self, /, *args, **kwargs): + if is_break_graph_tensor_methods(self.method_name): + raise BreakGraphError() + return self.graph.call_tensor_method(self.method_name, *args, **kwargs) + + def bind(self, instance: VariableBase, name: str): + method_var = MethodVariable( + instance, + self, + graph=self.graph, + tracker=GetAttrTracker(instance, name), + ) + class_var = VariableFactory.from_value( + instance.get_py_type(), + graph=self.graph, + tracker=ConstTracker(instance.get_py_type()), + ) + assert class_var is not None + self.tracker = GetAttrTracker(class_var, name) + return method_var + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__name__, + } + + +class MethodVariable(CallableVariable): + """ + MethodVariable is a subclass of CallableVariable used to wrap a method variable. + + Args: + bound_instance (VariableBase): The instance of the method. + fn (VariableBase): The method to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + method_name (str): The name of the method to be wrapped. + """ + + def __init__( + self, + bound_instance: VariableBase, + fn: VariableBase, + graph: FunctionGraph, + tracker: Tracker, + *, + method_name: str | None = None, + ): + super().__init__(graph, tracker) + self.bound_instance = bound_instance + self.fn = fn + self.method_name = method_name + + def get_py_value(self, allow_tensor=False): + return self.fn.get_py_value().__get__( + self.bound_instance.get_py_value(allow_tensor), + self.bound_instance.get_py_value(allow_tensor).__class__, + ) + + def _reconstruct(self, pycode_gen): + assert self.method_name is not None + self.tensor.reconstruct(pycode_gen) + pycode_gen.gen_load_attr(self.method_name) + + def call_function(self, /, *args, **kwargs): + return self.fn(*(self.bound_instance, *args), **kwargs) + + @staticmethod + def wrap_method( + value: types.MethodType, + *, + graph: FunctionGraph, + tracker: Tracker, + instance: VariableBase | None = None, + fn: VariableBase | None = None, + method_name: str | None = None, + ): + # NOTE(SigureMo): Since the method_self need method_var as the obj + # of the tracker, we need to temporarily set the tracker of method_self + # to DummyTracker, and set it to GetAttrTracker after method_var is created. + instance_var = ( + VariableFactory.from_value(value.__self__, graph, DanglingTracker()) + if instance is None + else instance + ) + + fn_var = ( + VariableFactory.from_value(value.__func__, graph, DanglingTracker()) + if fn is None + else fn + ) + + method_var = MethodVariable( + instance_var, + fn_var, + method_name=method_name, + graph=graph, + tracker=tracker, + ) + if instance is None: + instance_var.tracker = GetAttrTracker(method_var, "__self__") + if fn is None: + fn_var.tracker = GetAttrTracker(method_var, "__func__") + return method_var + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if inspect.ismethod(value): + return MethodVariable.wrap_method( + value=value, tracker=tracker, graph=graph + ) + return None + + @property + def main_info(self) -> dict[str, Any]: + return { + "method": self.method_name, + } + + +class LayerVariable(CallableVariable): + """ + LayerVariable is a subclass of CallableVariable used to wrap a layer. + + Args: + layer (paddle.nn.Layer): The layer to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.value = layer + + def get_py_value(self, allow_tensor=False): + return self.value + + def call_function(self, /, *args, **kwargs): + fn_var = UserDefinedFunctionVariable( + self.value.__class__.__call__, + self.graph, + GetAttrTracker(self, "__call__"), + ) + + return fn_var(*(self, *args), **kwargs) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + return [ + StringifyExpression( + f"id({{}}) == {id(self.get_py_value())}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + StringifyExpression( + f"{{}}.training == {self.get_py_value().training}", + [frame_value_tracer], + union_free_vars(frame_value_tracer.free_vars), + ), + ] + + +class ContainerLayerVariable(LayerVariable): + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + def __len__(self): + return len(self.value) + + def len(self): + return ConstantVariable(len(self), self.graph, DummyTracker([self])) + + def getitem(self, key): + if isinstance(self.value, PD_SEQ_CONTAINERS) and isinstance( + key, SliceVariable + ): + try: + slice_py_value = key.get_py_value() + new_layer_list = self.value[slice_py_value] + self.graph.add_global_guarded_variable(key) + return VariableFactory.from_value( + new_layer_list, + self.graph, + GetItemTracker(self, slice_py_value), + ) + except Exception as e: + raise BreakGraphError( + f"call {self.value.__class__.__name__}.__getitem__ with slice as key, and slice with py value failed: {e}." + ) + + else: + return super().getitem(key) + + def get_iter(self): + if isinstance(self.value, PD_SEQ_CONTAINERS): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + else: + return super().get_iter() + + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.value, PD_SEQ_CONTAINERS): + frame_value_tracer = self.tracker.trace_value_from_frame() + + len_guard = StringifyExpression( + f"len({{}}) == {len(self.value)}", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + + guards = [len_guard] + for idx, layer in enumerate(self.value): + layer_variable = VariableFactory.from_value( + layer, self.graph, GetItemTracker(self, idx) + ) + guards.extend(layer_variable.make_stringify_guard()) + + return guards + else: + return super().make_stringify_guard() + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="PaddleLayerVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, PD_ALL_CONTAINERS): + return ContainerLayerVariable(value, graph, tracker) + return None + + +class PaddleLayerVariable(LayerVariable): + """ + PaddleLayerVariable is a subclass of LayerVariable used to wrap a paddlepaddle layer. + + Args: + layer (paddle.nn.Layer): The paddle built-in layer to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + def call_function(self, /, *args, **kwargs): + self.graph.add_global_guarded_variable(self) + return self.graph.call_layer(self, *args, **kwargs) + + def make_stringify_guard(self) -> list[StringifyExpression]: + if isinstance(self.tracker, CreateLayerTracker): + return reduce( + operator.add, + [var.make_stringify_guard() for var in self.tracker.inputs], + ) + else: + return super().make_stringify_guard() + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="UserDefinedLayerVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + # TODO(SigureMo): Add a more common way to check if a value is a paddle builtin layer. + if isinstance(value, paddle.nn.Layer): + # If there is a user-defined behavior, such as a container class layer + # or a hook on the layer, it needs to be converted to UserDefinedLayerVariable, + # otherwise converted to PaddleLayerVariable + if ( + hasattr(value, "_forward_pre_hooks") + and value._forward_pre_hooks + or hasattr(value, "_forward_post_hooks") + and value._forward_post_hooks + ): + return None + if value.__module__.startswith("paddle.nn."): + return PaddleLayerVariable(value, graph, tracker) + return None + + +class UserDefinedLayerVariable(LayerVariable): + """ + UserDefinedLayerVariable is a subclass of LayerVariable used to wrap a user-defined layer. + + Args: + layer (paddle.nn.Layer): The user-defined layer to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, layer: paddle.nn.Layer, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(layer, graph, tracker) + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__class__.__name__, + } + + @VariableFactory.register_from_value(successor="PaddleApiVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if isinstance(value, paddle.nn.Layer): + return UserDefinedLayerVariable(value, graph, tracker) + return None + + +class BuiltinVariable(FunctionVariable): + """ + BuiltinVariable is a subclass of FunctionVariable used to wrap a built-in function. + Args: + fn (Callable[..., Any]): The built-in function to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + self.value = fn + + def call_function(self, /, *args, **kwargs): + # Lookup the handler from dispatcher + handler = Dispatcher.dispatch(self.value, *args, **kwargs) + if handler is not None: + return handler(*args, **kwargs) + + # Try to inline call the magic function + magic_methods = magic_method_builtin_dispatch(self.value) + for magic_method in magic_methods: + sorted_args = args + if magic_method.is_reverse: + sorted_args = sorted_args[::-1] + arg_type = sorted_args[0].get_py_type() + if hasattr(arg_type, magic_method.name): + class_fn = getattr(arg_type, magic_method.name) + class_var = VariableFactory.from_value( + arg_type, + self.graph, + GetAttrTracker(args[0], "__class__"), + ) + assert isinstance(class_var, VariableBase) + fn_var = VariableFactory.from_value( + class_fn, + self.graph, + GetAttrTracker(class_var, class_fn.__name__), + ) + assert isinstance(fn_var, VariableBase) + return fn_var(*args) + + # Break graph if neither of the above conditions is met + arg_types = ", ".join([type(arg).__name__ for arg in args]) + fn_name = ( + self.value.__name__ + if hasattr(self.value, '__name__') + else self.value + ) + raise BreakGraphError( + f"Not support builtin function: {fn_name} with args: Args({arg_types})" + ) + + @VariableFactory.register_from_value(successor="ClassVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if is_builtin_fn(value): + return BuiltinVariable(value, graph, tracker) + return None + + @property + def main_info(self) -> dict[str, Any]: + return { + "name": self.value.__name__, + } + + +class UserDefinedGeneratorVariable(FunctionVariable): + """ + UserDefinedGeneratorVariable is a subclass of FunctionVariable used to wrap a user-defined generator. + Args: + fn (Callable[..., Any]): The user-defined generator to be wrapped. + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker + ): + super().__init__(fn, graph, tracker) + + def call_function(self, /, *args, **kwargs): + iter_ = self.value(*args, **kwargs) + var = VariableFactory.from_value( + iter_, self.graph, DummyTracker([self]) + ) + return var + + @property + def main_info(self) -> dict[str, Any]: + return {"name": self.value.__name__} + + @VariableFactory.register_from_value( + successor="UserDefinedFunctionVariable" + ) + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if inspect.isgeneratorfunction(value): + return UserDefinedGeneratorVariable(value, graph, tracker) + return None + + +class ClassVariable(CallableVariable): + def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): + super().__init__(graph, tracker) + self.value = class_ + + def get_py_value(self, allow_tensor=False): + return self.value + + def call_function(self, /, *args, **kwargs): + new_object = self.value.__new__(self.value) + + # do not have init function + if self.value.__init__ is object.__init__: + return VariableFactory.from_value( + new_object, self.graph, DummyTracker([self]) + ) + + if not hasattr(self.value.__init__, "__code__"): + fn_var = BuiltinVariable( + self.value.__init__, + self.graph, + GetAttrTracker(self, "__init__"), + ) + else: + fn_var = UserDefinedFunctionVariable( + self.value.__init__, + self.graph, + GetAttrTracker(self, "__init__"), + ) + + # need classify variable type here? + new_object_variable = VariableFactory.from_value( + new_object, + self.graph, + DummyTracker([self] + list(args) + list(kwargs.values())), + ) + fn_var(new_object_variable, *args, **kwargs) + return new_object_variable + + make_stringify_guard = object_equal_stringify_guard + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if inspect.isclass(value): + return ClassVariable(value, graph, tracker) + return None + + +class PaddleLayerClassVariable(ClassVariable): + def __init__(self, class_: type, graph: FunctionGraph, tracker: Tracker): + super().__init__(class_, graph, tracker) + + def call_function(self, /, *args, **kwargs): + input_py_args = [var.get_py_value() for var in args] + input_py_kwargs = {k: v.get_py_value() for k, v in kwargs.items()} + new_layer = self.value(*input_py_args, **input_py_kwargs) + return PaddleLayerVariable( + new_layer, self.graph, CreateLayerTracker(self, args, kwargs) + ) + + @VariableFactory.register_from_value(successor="ClassVariable") + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if ( + inspect.isclass(value) + and issubclass(value, paddle.nn.Layer) + and value.__module__.startswith("paddle.nn.") + ): + return PaddleLayerClassVariable(value, graph, tracker) + return None diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/container.py b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py new file mode 100644 index 00000000000000..b1c318e9187bd1 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/container.py @@ -0,0 +1,1011 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import operator +from collections import OrderedDict +from functools import reduce +from typing import TYPE_CHECKING, Any + +from ....utils.exceptions import FallbackError, InnerError +from ..dispatcher import Dispatcher +from ..guard import StringifyExpression, check_guard +from ..mutable_data import MutableDictLikeData, MutableListLikeData +from ..pycode_generator import PyCodeGen +from ..tracker import ( + ConstTracker, + DanglingTracker, + DummyTracker, + GetItemTracker, + GetIterTracker, + Tracker, +) +from .base import ConstTypes, VariableBase, VariableFactory +from .basic import ConstantVariable +from .callable import BuiltinVariable, UserDefinedFunctionVariable + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + + +class ContainerVariable(VariableBase): + """ + ContainerVariable is a wrapper for container types, such as range, list, tuple, dict. + """ + + @property + def init_value(self): + return self.value + + def get_items(self) -> list[VariableBase]: + raise FallbackError('ContainerVariable.get_items do not implement') + + def get_wrapped_items(self): + raise FallbackError( + "ContainerVariable.get_wrapped_items do not implement" + ) + + def __len__(self): + raise FallbackError('ContainerVariable.__len__ do not implement') + + def len(self): + return ConstantVariable(len(self), self.graph, DummyTracker([self])) + + def __bool__(self) -> bool: + return len(self) > 0 + + def bool(self): + return ConstantVariable(bool(self), self.graph, DummyTracker([self])) + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + type_guard = StringifyExpression( + f"isinstance({{}}, {self.get_py_type().__name__})", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + len_guard = StringifyExpression( + f"len({{}}) == {len(self.init_value)}", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + if isinstance(self, (ListVariable, TupleVariable)): + guard_variables = self.proxy.reproduce(0) + + elif isinstance(self, DictVariable): + guard_variables = filter( + lambda var: not isinstance(var, MutableDictLikeData.Empty), + self.proxy.reproduce(0).values(), + ) + else: + raise InnerError(f"Unsupported container type: {type(self)}") + return reduce( + operator.add, + [[type_guard, len_guard]] + + [item.make_stringify_guard() for item in guard_variables], + ) + + +class ListVariable(ContainerVariable): + """ + ListVariable is a wrapper for list and contains common APIs for list methods + + Args: + val_list(List[VariableBase]): the list to wrap + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + val_list: list[VariableBase], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + + # everything in stack is VariableBase, so just accept the input list is ok + self.proxy = self.graph.side_effects.get_proxy( + MutableListLikeData, val_list, self.proxy_getter + ) + self.value = val_list + + def proxy_getter(self, proxy: MutableListLikeData, key: Any): + if key < 0 or key >= len(proxy.original_data): + return MutableListLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=GetItemTracker(self, key, changed=proxy.has_changed), + ) + + def get_py_value(self, allow_tensor=False): + items = self.proxy.get_all() + return [item.get_py_value(allow_tensor) for item in items] + + def get_py_type(self): + return list + + def _reconstruct(self, codegen: PyCodeGen): + size = len(self) + for idx in range(size): + Dispatcher.call(operator.getitem, self, idx).reconstruct(codegen) + codegen.gen_build_list(size) + + def get_items(self): + size = len(self) + return [ + Dispatcher.call(operator.getitem, self, idx) for idx in range(size) + ] + + def get_wrapped_items(self): + return self.get_items() + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + @property + def main_info(self) -> dict[str, Any]: + return { + "len": len(self), + } + + def __len__(self): + return self.proxy.length + + def getitem(self, key): + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + if isinstance(key, int): + res = self.proxy.get(key) + if self.proxy.is_empty(res): + raise InnerError(f"List {self} out of range (index={key})") + return res + elif isinstance(key, slice): + items = self.proxy.get_all() + return VariableFactory.from_value( + items[key], + self.graph, + tracker=GetItemTracker( + self, key, changed=self.proxy.has_changed + ), + ) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} for ListVariable" + ) + + def setitem(self, key, value): + if not isinstance(value, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {value} to set value." + ) + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice) and isinstance( + value, (ListVariable, TupleVariable) + ): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value.get_wrapped_items()): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise InnerError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} and value type {value.__class__.__name__} for ListVariable" + ) + + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def __delitem__(self, key): + return self.delitem(key) + + def delitem(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: received {key} as key to delete." + ) + self.proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def insert(self, index: int, value: VariableBase): + self.proxy.insert(index, value) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def append(self, value: VariableBase): + self.insert(self.proxy.length, value) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def extend(self, data): + for item in data.proxy.get_all(): + self.append(item) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def concat(self, list_): + assert isinstance(list_, ListVariable) + return ListVariable( + self.proxy.get_all() + list_.proxy.get_all(), + self.graph, + DummyTracker([self, list_]), + ) + + def repeat(self, length): + assert isinstance(length, ConstantVariable) + return ListVariable( + self.proxy.get_all() * length.value, + self.graph, + DummyTracker([self, length]), + ) + + def pop(self, index: ConstantVariable | None = None): + if index is None: + index = ConstantVariable.wrap_literal(-1, self.graph) + res = self.proxy.get(index.get_py_value()) + self.proxy.delete(index.get_py_value()) + self.graph.side_effects.record_proxy_variable(self) + return res + + def copy(self): + return ListVariable( + self.proxy.get_all(), + self.graph, + DummyTracker([self]), + ) + + def clear(self): + for idx in range(self.proxy.length): + self.delitem(0) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def remove(self, value): + for idx in range(self.proxy.length): + if self[idx].get_py_value(allow_tensor=True) == value.get_py_value( + allow_tensor=True + ): + self.delitem(idx) + break + else: + raise InnerError(f"List {self} does not contain {value}") + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def sort(self, key=None, reverse=None): + if ( + key is None + or isinstance(key, ConstantVariable) + and key.get_py_value() is None + ): + key = UserDefinedFunctionVariable( + lambda x: x, self.graph, DanglingTracker() + ) + assert key is not None + if reverse is None: + reverse = ConstantVariable.wrap_literal(False, self.graph) + + permutation = list(range(self.proxy.length)) + permutation.sort( + key=lambda x: key.get_py_value()( + Dispatcher.call(operator.getitem, self, x).value + ), + reverse=reverse.get_py_value(), + ) + self.proxy.permutate(permutation) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def reverse(self): + permutation = list(range(self.proxy.length)) + permutation.reverse() + self.proxy.permutate(permutation) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def count(self, value: VariableBase): + count: int = 0 + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + if index_value.id == value.id: + count += 1 + continue + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + index_value, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_py_value() is True: + count += 1 + continue + + return ConstantVariable(count, self.graph, DummyTracker([self, value])) + + def index(self, value: VariableBase): + res = 0 + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + if index_value.id == value.id: + return ConstantVariable( + res, self.graph, DummyTracker([self, value]) + ) + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + index_value, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_py_value() is True: + return ConstantVariable( + res, self.graph, DummyTracker([self, value]) + ) + res += 1 + + return ConstantVariable(-1, self.graph, DummyTracker([self, value])) + + def max(self): + if len(self) == 0: + raise ValueError("max() arg is an empty sequence") + res = self[0] + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + gt = BuiltinVariable(operator.gt, self.graph, DanglingTracker())( + index_value, res + ) + if gt.get_py_value() is True: + res = index_value + return res + + def min(self): + if len(self) == 0: + raise ValueError("max() arg is an empty sequence") + res = self[0] + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + lt = BuiltinVariable(operator.lt, self.graph, DanglingTracker())( + index_value, res + ) + if lt.get_py_value() is True: + res = index_value + return res + + def getattr(self, name: str, default=None): + from .callable import BuiltinVariable + + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + + method_name_to_builtin_fn = { + "insert": list.insert, + "append": list.append, + "extend": list.extend, + "pop": list.pop, + "copy": list.copy, + "clear": list.clear, + "remove": list.remove, + "sort": list.sort, + "reverse": list.reverse, + "count": list.count, + "index": list.index, + } + + if name in method_name_to_builtin_fn: + builtin_fn = method_name_to_builtin_fn[name] + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + else: + raise FallbackError(f"attribute {name} for list is not implemented") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + # Note(SigureMo): Why not use isinstance? + # Because user may define a class that inherit from list. + # We should convert it to ObjectVariable instead of ListVariable. + if type(value) is list: # noqa: E721 + return ListVariable(value, graph=graph, tracker=tracker) + return None + + +class TupleVariable(ContainerVariable): + """ + TupleVariable is a wrapper for tuple and contains common APIs for tuple methods. + + Args: + val_tuple(tuple[VariableBase, ...]): the tuple to wrap + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + val_tuple: tuple[VariableBase, ...], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + + self.proxy = self.graph.side_effects.get_proxy( + MutableListLikeData, list(val_tuple), self.proxy_getter + ) + self.value = val_tuple + + def getattr(self, name: str, default=None): + from .callable import BuiltinVariable + + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + + method_name_to_builtin_fn = { + "count": tuple.count, + "index": tuple.index, + } + if name in method_name_to_builtin_fn: + builtin_fn = method_name_to_builtin_fn[name] + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + else: + raise FallbackError( + f"attribute {name} for tuple is not implemented" + ) + + def proxy_getter(self, proxy: MutableListLikeData, key: Any): + if key < 0 or key >= len(proxy.original_data): + return MutableListLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=GetItemTracker(self, key, changed=False), + ) + + def get_py_value(self, allow_tensor=False): + return tuple( + self[idx].get_py_value(allow_tensor) for idx in range(len(self)) + ) + + def get_py_type(self): + return tuple + + def _reconstruct(self, codegen: PyCodeGen): + size = len(self) + for idx in range(size): + Dispatcher.call(operator.getitem, self, idx).reconstruct(codegen) + codegen.gen_build_tuple(size) + + def get_items(self): + size = len(self) + return [ + Dispatcher.call(operator.getitem, self, idx) for idx in range(size) + ] + + def get_wrapped_items(self): + return tuple(self.get_items()) + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + @property + def main_info(self) -> dict[str, Any]: + return { + "len": len(self), + } + + def __len__(self): + return self.proxy.length + + def getitem(self, key): + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + if isinstance(key, int): + res = self.proxy.get(key) + if self.proxy.is_empty(res): + raise InnerError(f"List {self} out of range (index={key})") + return res + elif isinstance(key, slice): + return TupleVariable( + tuple(self.proxy.get_all())[key], + self.graph, + tracker=GetItemTracker(self, key, changed=False), + ) + else: + raise InnerError( + f"Unsupported key type {key.__class__.__name__} for TupleVariable" + ) + + def setitem(self, key, value): + raise InnerError( + f"[{self.__class__.__name__}]: setitem is not allowed." + ) + + def __delitem__(self, key): + return self.delitem(key) + + def delitem(self, key): + raise InnerError( + f"[{self.__class__.__name__}]: delitem is not allowed." + ) + + def concat(self, tuple_): + assert isinstance(tuple_, TupleVariable) + new_tuple_variable = TupleVariable( + tuple(self.proxy.get_all() + tuple_.proxy.get_all()), + self.graph, + DummyTracker([self, tuple_]), + ) + return new_tuple_variable + + def repeat(self, length): + assert isinstance(length, ConstantVariable) + new_tuple_variable = TupleVariable( + tuple(self.proxy.get_all()) * length.value, + self.graph, + DummyTracker([self, length]), + ) + return new_tuple_variable + + def count(self, value: VariableBase): + count: int = 0 + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + if index_value.id == value.id: + count += 1 + continue + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + index_value, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_py_value() is True: + count += 1 + continue + + return ConstantVariable(count, self.graph, DummyTracker([self, value])) + + def index(self, value: VariableBase): + res = 0 + getitem = BuiltinVariable( + operator.getitem, self.graph, DanglingTracker() + ) + for index in range(len(self)): + index_value = getitem(self, index) + if index_value.id == value.id: + return ConstantVariable( + res, self.graph, DummyTracker([self, value]) + ) + eq = BuiltinVariable(operator.eq, self.graph, DanglingTracker())( + index_value, value + ) + eq_bool = BuiltinVariable(bool, self.graph, DanglingTracker())(eq) + assert isinstance( + eq_bool, ConstantVariable + ), "bool should return ConstantVariable" + if eq.get_py_value() is True: + return ConstantVariable( + res, self.graph, DummyTracker([self, value]) + ) + res += 1 + + return ConstantVariable(-1, self.graph, DummyTracker([self, value])) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) is tuple: + return TupleVariable(value, graph, tracker) + return None + + +class RangeVariable(ContainerVariable): + """ + RangeVariable is a wrapper for range. + + Args: + val_range(range): the range to wrap + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + val_range: range, + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + self.value = val_range + + def get_py_type(self): + return range + + def get_py_value(self, allow_tensor=False): + return self.value + + def getitem(self, key): + self.graph.add_global_guarded_variable(self) + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + retval = self.value[key] + return ConstantVariable.wrap_literal(retval, self.graph) + + def get_items(self): + size = len(self) + return [self[idx] for idx in range(size)] + + def get_wrapped_items(self): + return self.get_items() + + def get_iter(self): + from .iter import SequenceIterVariable + + return SequenceIterVariable(self, self.graph, GetIterTracker(self)) + + def __len__(self): + return len(self.value) + + def _reconstruct(self, codegen: PyCodeGen): + codegen.gen_load_global("range", push_null=True) + # The start default value is 0, step is 1 + # So we can always construct range with 3 args + codegen.gen_load_const(self.value.start) + codegen.gen_load_const(self.value.stop) + codegen.gen_load_const(self.value.step) + codegen.gen_call_function(3) + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) is range: + return RangeVariable(value, graph, tracker) + return None + + @check_guard + def make_stringify_guard(self) -> list[StringifyExpression]: + frame_value_tracer = self.tracker.trace_value_from_frame() + + return [ + StringifyExpression( + "isinstance({0}, range) and " + + f"{{0}}.start == {self.init_value.start} and " + + f"{{0}}.stop == {self.init_value.stop} and " + + f"{{0}}.step == {self.init_value.step}", + [frame_value_tracer], + frame_value_tracer.free_vars, + ) + ] + + @property + def debug_name(self) -> str: + return ":".join( + [ + str(self.value.start) if self.value.start is not None else "", + str(self.value.stop) if self.value.stop is not None else "", + str(self.value.step) if self.value.step is not None else "", + ] + ) + + @debug_name.setter + def debug_name(self, name): + pass + + @property + def main_info(self) -> dict[str, Any]: + return {"value": self.value} + + +class DictVariable(ContainerVariable): + """ + DictVariable is a wrapper for dict and contains common APIs for dict methods + + Args: + val_dict(dict[object, VariableBase]): the dict to wrap + graph(FunctionGraph): The FunctionGraph object that this variable is associated with. + tracker(Tracker): The Tracker object that tracks the information of this variable. + """ + + def __init__( + self, + val_dict: dict[object, VariableBase], + graph: FunctionGraph, + tracker: Tracker, + ): + super().__init__(graph, tracker) + + self.proxy = self.graph.side_effects.get_proxy( + MutableDictLikeData, val_dict, self.proxy_getter + ) + self.value = val_dict + + def proxy_getter(self, proxy: MutableDictLikeData, key: Any): + if key not in proxy.original_data: + return MutableDictLikeData.Empty() + return VariableFactory.from_value( + proxy.original_data[key], + self.graph, + tracker=GetItemTracker(self, key, changed=proxy.has_changed), + ) + + def get_py_value(self, allow_tensor=False): + return { + key: value.get_py_value(allow_tensor) + for key, value in self.proxy.get_all().items() + } + + def get_py_type(self): + return dict + + def _reconstruct(self, codegen: PyCodeGen): + from .basic import ConstantVariable + + size = len(self) + for key in self.proxy.get_all().keys(): + if not isinstance(key, ConstTypes): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + key_var = ConstantVariable.wrap_literal(key, self.graph) + value_var = self[key] + key_var.reconstruct(codegen) + value_var.reconstruct(codegen) + codegen.gen_build_map(size) + + def get_items(self): + items = [] + for key in self.proxy.get_all().keys(): + if not isinstance(key, ConstTypes): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + key_var = VariableFactory.from_value( + key, self.graph, tracker=ConstTracker(key) + ) + value_var = self[key] + items.extend([key_var, value_var]) + return items + + def get_wrapped_items(self): + items = {} + for key in self.proxy.get_all().keys(): + if not isinstance(key, ConstTypes): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + items[key] = self[key] + return items + + def get_iter(self): + return self.keys() + + @property + def main_info(self) -> dict[str, Any]: + return { + "len": len(self), + } + + def __len__(self): + return len(self.proxy.get_all()) + + def get(self, key, default=None): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} to get value." + ) + + if default is None: + return Dispatcher.call(operator.getitem, self, key) + + if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): + assert isinstance(default, VariableBase) + return default + + return Dispatcher.call(operator.getitem, self, key) + + def getitem(self, key): + self.graph.add_global_guarded_variable(key) + key = key.get_py_value() + return self.proxy.get(key) + + def setitem(self, key, value): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key." + ) + + if not isinstance(value, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {value} to set value." + ) + + self.proxy.set(key, value) + self.graph.side_effects.record_proxy_variable(self) + + return ConstantVariable.wrap_literal(None, self.graph) + + def clear(self): + # TODO: Replace with self.proxy.clear() + for key in self.value: + self.delitem(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def __delitem__(self, key): + return self.delitem(key) + + def delitem(self, key): + if isinstance(key, VariableBase): + raise InnerError( + f"[{self.__class__.__name__}]: recieved {key} as key to delete." + ) + self.proxy.delete(key) + self.graph.side_effects.record_proxy_variable(self) + return ConstantVariable.wrap_literal(None, self.graph) + + def keys(self): + from .iter import SequenceIterVariable + + raw_list = [ + ConstantVariable(x, self.graph, ConstTracker(x)) + for x in self.proxy.get_all().keys() + ] + key_list = ListVariable(raw_list, self.graph, DummyTracker(raw_list)) + assert key_list is not None + return SequenceIterVariable( + key_list, self.graph, DummyTracker([key_list]) + ) + + def values(self): + from .iter import SequenceIterVariable + + raw_list = list(self.get_wrapped_items().values()) + value_list = ListVariable(raw_list, self.graph, DummyTracker([self])) + assert value_list is not None + return SequenceIterVariable( + value_list, self.graph, DummyTracker([value_list]) + ) + + def items(self): + from .iter import SequenceIterVariable + + keys = [ + ConstantVariable(x, self.graph, ConstTracker(x)) + for x in self.proxy.get_all().keys() + ] + values = list(self.get_wrapped_items().values()) + raw_list = list(zip(keys, values)) + item_list = ListVariable(raw_list, self.graph, DummyTracker([self])) + assert item_list is not None + return SequenceIterVariable( + item_list, self.graph, DummyTracker([item_list]) + ) + + def update(self, data: DictVariable): + for key, value in data.proxy.get_all().items(): + self.setitem(key, value) + return ConstantVariable.wrap_literal(None, self.graph) + + def copy(self): + new_dict_variable = DictVariable( + self.get_wrapped_items(), self.graph, DummyTracker([self]) + ) + return new_dict_variable + + def setdefault(self, key, default=None): + if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): + if default is None: + self.setitem( + key, ConstantVariable.wrap_literal(default, self.graph) + ) + else: + self.setitem(key, default) + + return Dispatcher.call(operator.getitem, self, key) + + def pop(self, key, default=None): + if isinstance(self.proxy.get(key), MutableDictLikeData.Empty): + assert isinstance(default, VariableBase) + return default + + # default is not None, or key is in dict + temp_value = Dispatcher.call(operator.getitem, self, key) + self.delitem(key) + return temp_value + + def popitem(self): + key = self.keys().hold.get_py_value()[-1] + value = Dispatcher.call(operator.getitem, self, key) + # TODO: key, value should be VariableBase but key maybe a int + # assert isinstance(key, VariableBase), key + # assert isinstance(value, VariableBase), value + new_tuple_variable = TupleVariable( + (key, value), self.graph, DummyTracker([self]) + ) + self.delitem(key) + return new_tuple_variable + + def getattr(self, name: str, default=None): + from .callable import BuiltinVariable + + if default is not None: + raise FallbackError( + "default argument for getattr is not implemented" + ) + + method_name_to_builtin_fn = { + "keys": dict.keys, + "values": dict.values, + "items": dict.items, + "update": dict.update, + "setdefault": dict.setdefault, + "get": dict.get, + "copy": dict.copy, + "clear": dict.clear, + "pop": dict.pop, + "popitem": dict.popitem, + } + + if name in method_name_to_builtin_fn: + builtin_fn = method_name_to_builtin_fn[name] + return BuiltinVariable( + builtin_fn, self.graph, DanglingTracker() + ).bind(self, name) + else: + raise FallbackError(f"attribute {name} for dict is not implemented") + + @VariableFactory.register_from_value() + def from_value(value: Any, graph: FunctionGraph, tracker: Tracker): + if type(value) in (dict, OrderedDict): + return DictVariable(value, graph=graph, tracker=tracker) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py b/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py new file mode 100644 index 00000000000000..82ff8fe2534a74 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/iter.py @@ -0,0 +1,203 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ....utils import BreakGraphError, FallbackError +from ..pycode_generator import PyCodeGen +from ..tracker import ConstTracker, DummyTracker +from .base import VariableBase +from .basic import ConstantVariable +from .container import ContainerVariable, TupleVariable + +if TYPE_CHECKING: + from ..function_graph import FunctionGraph + from ..tracker import Tracker + + +class IterVariable(VariableBase): + """ + This Variable (include subclasses) should be generated only when simulate GET_ITER opcode + """ + + def __init__( + self, obj: VariableBase, graph: FunctionGraph, tracker: Tracker + ): + super().__init__(graph, tracker) + self.hold = obj + + def make_stringify_guard(self): + return self.hold.make_stringify_guard() + + def next(self): + raise NotImplementedError(f"Can not simulate `next` for {type(self)}") + + def get_iter(self): + return self + + def get_hold(self): + return self.hold + + +class SequenceIterVariable(IterVariable): + """ + The basic SequenceIterVariable wraps iterators which can be simulated by call getitem + Currently includes: List | Tuple | Dict (keys) | Range | Tensor | nn.LayerList + """ + + mutable_attrs = ["idx"] + + def __init__(self, obj, graph: FunctionGraph, tracker: Tracker): + super().__init__(obj, graph, tracker) + self.idx = 0 + self.graph.side_effects.record_mutable_variable(self) + + def next(self): + # TODO: self.hold should have a __len__ method + if self.idx < len(self.hold): + val = self.hold[self.idx] + self.idx += 1 + return val + else: + raise StopIteration() + + def to_list(self) -> list: + if self.has_side_effect(): + raise FallbackError("Can not convert an used iterator into list") + self.idx = len(self.hold) + retval = [] + for i in range(len(self.hold)): + retval.append(self.hold[i]) + return retval + + def has_side_effect(self) -> bool: + return self.idx != 0 + + @property + def main_info(self) -> dict[str, Any]: + return { + "idx": self.idx, + } + + def _reconstruct(self, codegen: PyCodeGen): + if self.has_side_effect(): + super()._reconstruct(codegen) + else: + self.hold.reconstruct(codegen) + codegen.gen_get_iter() + + +class EnumerateVariable(SequenceIterVariable): + """ + EnumerateVariable holds a SequenceIterVariable and return additional index + """ + + def __init__(self, val_iterator, graph, tracker): + super().__init__(val_iterator, graph, tracker) + + def next(self): + val = self.hold.next() + idx_var = ConstantVariable(self.idx, self.graph, ConstTracker(self.idx)) + self.idx += 1 + return TupleVariable( + (idx_var, val), self.graph, DummyTracker([idx_var, val]) + ) + + def to_list(self): + values = self.hold.to_list() + idx = [ + ConstantVariable(i, self.graph, ConstTracker(i)) + for i in range(len(values)) + ] + return list(zip(idx, values)) + + def has_side_effect(self) -> bool: + return self.hold.has_side_effect() or self.idx != 0 + + def _reconstruct(self, codegen: PyCodeGen): + if self.has_side_effect(): + super()._reconstruct(codegen) + else: + codegen.gen_load_global("enumerate", push_null=True) + self.hold.reconstruct(codegen) + codegen.gen_call_function(1) + + def get_hold(self): + return self.hold.get_hold() + + @staticmethod + def from_iterator(value, graph: FunctionGraph | None, tracker: Tracker): + iter_variable = value.get_iter() + if isinstance(iter_variable, SequenceIterVariable): + return EnumerateVariable(iter_variable, graph, tracker) + else: + return UserDefinedIterVariable(value, graph, tracker) + + +class MapVariable(SequenceIterVariable): + """ + MapVariable holds a SequenceIterVariable and return a Iterable Variable after map function + """ + + def __init__(self, func, val_iterator, graph, tracker): + super().__init__(val_iterator, graph, tracker) + self.func = func + + def next(self): + return self.func(self.hold.next()) + + def to_list(self) -> list: + retval = [] + while True: + try: + retval.append(self.func(self.hold.next())) + except StopIteration: + break + return retval + + def has_side_effect(self) -> bool: + return self.hold.has_side_effect() + + def _reconstruct(self, codegen: PyCodeGen): + if self.has_side_effect(): + super()._reconstruct(codegen) + else: + codegen.gen_load_global("map", push_null=True) + self.func.reconstruct(codegen) + self.hold.reconstruct(codegen) + codegen.gen_call_function(2) + + @staticmethod + def from_iterator( + func, value, graph: FunctionGraph | None, tracker: Tracker + ): + iter_variable = ( + value.get_iter() if isinstance(value, ContainerVariable) else value + ) + + if isinstance(iter_variable, IterVariable): + return MapVariable(func, iter_variable, graph, tracker) + else: + return UserDefinedIterVariable(value, graph, tracker) + + +# what UserDefinedIterVariable holds doesn't matter, because use user defined iterator will trigger break graph +class UserDefinedIterVariable(IterVariable): + def __init__(self, obj, graph, tracker): + super().__init__(obj, graph, tracker) + + def next(self): + raise BreakGraphError("Break graph when using user defined iterator") diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py new file mode 100644 index 00000000000000..5fc71359e93868 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .instruction_utils import ( # noqa: F401 + Instruction, + calc_offset_from_bytecode_offset, + calc_stack_effect, + convert_instruction, + gen_instr, + get_instructions, + instrs_info, + modify_extended_args, + modify_instrs, + modify_vars, + relocate_jump_target, + replace_instr, + reset_offset, +) +from .opcode_analysis import ( # noqa: F401 + Space, + analysis_inputs, + analysis_used_names_with_space, +) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py new file mode 100644 index 00000000000000..182ba54279eeff --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -0,0 +1,407 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +import dis +import sys +from typing import TYPE_CHECKING, Any + +from ...utils import InnerError +from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP + +if TYPE_CHECKING: + import types + + +@dataclasses.dataclass +class Instruction: + opcode: int + opname: str + arg: int | None + argval: Any + offset: int | None = None + starts_line: int | None = None + is_jump_target: bool = False + jump_to: Instruction | None = None + is_generated: bool = True + + # for analys EXTENDED_ARG + first_ex_arg: Instruction | None = None + ex_arg_for: Instruction | None = None + + # used in modify_extended_args + def __hash__(self): + return id(self) + + +def gen_instr(name, arg=None, argval=None, gened=True, jump_to=None): + return Instruction( + opcode=dis.opmap[name], + opname=name, + arg=arg, + argval=argval, + is_generated=gened, + jump_to=jump_to, + ) + + +def convert_instruction(instr: dis.Instruction) -> Instruction: + """ + Converts a disassembled instruction to a customized Instruction object. + + Args: + instr (dis.Instruction): The disassembled instruction. + + Returns: + Instruction: A customized Instruction object. + """ + return Instruction( + instr.opcode, + instr.opname, + instr.arg, + instr.argval, + instr.offset, + instr.starts_line, + instr.is_jump_target, + jump_to=None, + is_generated=False, + ) + + +def get_instructions(code: types.CodeType) -> list[Instruction]: + """ + Returns parsed instructions from the given code object and exclude + any opcodes that contain `EXTENDED_ARG`. + + Args: + code (types.CodeType): The code object to extract instructions from. + + Returns: + list[Instruction]: A list of Instruction objects representing the + bytecode instructions in the code object. + """ + # instrs do not contain EXTENDED_ARG + instrs = list(map(convert_instruction, dis.get_instructions(code))) + for instr in instrs: + if instr.opname in ALL_JUMP: + origin_jump_target = calc_offset_from_bytecode_offset( + instr.argval, instrs + ) + jump_offset = origin_jump_target + + while instrs[jump_offset].opname == "EXTENDED_ARG": + jump_offset += 1 + + if origin_jump_target != jump_offset: + # copy infos from EXETENDED_ARG to other opcode + + if instrs[origin_jump_target].is_jump_target: + instrs[jump_offset].is_jump_target = instrs[ + origin_jump_target + ].is_jump_target + if instrs[origin_jump_target].starts_line: + instrs[jump_offset].starts_line = instrs[ + origin_jump_target + ].starts_line + + instr.jump_to = instrs[jump_offset] + + # if the origin opcode contains EXTENDED_ARG, it should be like: + # >> EXTENDED_ARG 1 + # XX 388 <- 256 + 132 + # filter all EXTENDED_ARG here + instrs = [x for x in instrs if x.opname != "EXTENDED_ARG"] + return instrs + + +def modify_instrs(instructions: list[Instruction]) -> None: + """ + Modifies the given list of instructions. It contains three steps: + + 1. reset offset + 2. relocate jump target + 3. add EXTENDED_ARG instruction if needed + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + None + """ + modify_completed = False + while not modify_completed: + reset_offset(instructions) + relocate_jump_target(instructions) + modify_completed = modify_extended_args(instructions) + + +def reset_offset(instructions: list[Instruction]) -> None: + """ + Resets the offset for each instruction in the list. + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + None + """ + from ..executor.pycode_generator import get_instruction_size + + if sys.version_info >= (3, 11): + current_offset = 0 + for instr in instructions: + instr.offset = current_offset + current_offset += get_instruction_size(instr) + return + for idx, instr in enumerate(instructions): + instr.offset = idx * 2 + + +def correct_jump_direction(instr: Instruction, arg: int) -> Instruction: + """ + Corrects the jump direction of the given instruction. + NOTE(zrr1999): In Python 3.11, JUMP_ABSOLUTE is removed, so python generates JUMP_FORWARD or JUMP_BACKWARD instead, + but in for loop breakgraph, we reuse JUMP_BACKWARD to jump forward, so we need to change it to JUMP_FORWARD. + + Args: + instr (Instruction): The instruction to be corrected. + """ + if instr.opname in ABS_JUMP: + instr.arg = arg + return instr + elif instr.opname in REL_JUMP: + if arg < 0: + if instr.opname in REL_BWD_JUMP: + forward_op_name = instr.opname.replace("BACKWARD", "FORWARD") + if forward_op_name not in dis.opmap: + raise InnerError(f"Unknown jump type {instr.opname}") + instr.opname = forward_op_name + instr.opcode = dis.opmap[forward_op_name] + else: # instr.opname in REL_FWD_JUMP + backward_op_name = instr.opname.replace("FORWARD", "BACKWARD") + if backward_op_name not in dis.opmap: + raise InnerError(f"Unknown jump type {instr.opname}") + instr.opname = backward_op_name + instr.opcode = dis.opmap[backward_op_name] + instr.arg = -arg + else: + instr.arg = arg + return instr + else: + raise ValueError(f"unknown jump type: {instr.opname}") + + +def relocate_jump_target(instructions: list[Instruction]) -> None: + """ + If a jump instruction is found, this function will adjust the jump targets based on the presence of EXTENDED_ARG instructions. + If an EXTENDED_ARG instruction exists for the jump target, use its offset as the new target. + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + None + """ + extended_arg = [] + for instr in instructions: + if instr.opname == "EXTENDED_ARG": + extended_arg.append(instr) + continue + + if instr.opname in ALL_JUMP: + assert instr.jump_to is not None + assert instr.offset is not None + # if jump target has extended_arg, should jump to the first extended_arg opcode + jump_target = ( + instr.jump_to.offset + if instr.jump_to.first_ex_arg is None + else instr.jump_to.first_ex_arg.offset + ) + assert jump_target is not None + + if instr.opname in ABS_JUMP: + new_arg = jump_target + else: # instr.opname in REL_JUMP + new_arg = jump_target - instr.offset - 2 + if instr.opname in REL_BWD_JUMP: + new_arg = -new_arg + + if sys.version_info >= (3, 10): + new_arg //= 2 + correct_jump_direction(instr, new_arg) + assert instr.arg is not None + if extended_arg: + instr.arg &= 0xFF + new_arg = new_arg >> 8 + for ex in reversed(extended_arg): + ex.arg = new_arg & 0xFF + new_arg = new_arg >> 8 + + # need more extended_args instr + # set arg in the first extended_arg + if new_arg > 0: + extended_arg[0].arg += new_arg << 8 + extended_arg.clear() + + +def modify_extended_args(instructions: list[Instruction]) -> bool: + """ + This function replaces any instruction with an argument greater than or equal to 256 with one or more EXTENDED_ARG instructions. + + Args: + instructions (list): The list of Instruction objects representing bytecode instructions. + + Returns: + bool: True if the modification is completed, False otherwise. + """ + + modify_completed = True + extend_args_record = {} + for instr in instructions: + if instr.arg and instr.arg >= 256: # more than one byte + _instrs = [ + instr + ] # replace instr with _instrs later (it is a set of instrs), all operations will be recorded in extend_args_record + val = instr.arg + instr.arg = val & 0xFF + val = val >> 8 + while val > 0: + _instrs.append(gen_instr("EXTENDED_ARG", arg=val & 0xFF)) + val = val >> 8 + + extend_args_record.update({instr: list(reversed(_instrs))}) + + if extend_args_record: + # if new EXTENDED_ARG inserted, we need update offset and jump target + modify_completed = False + + def bind_ex_arg_with_instr(ex_arg, instr): + # move opcode info to EXTENDED_ARG + ex_arg.starts_line = instr.starts_line + instr.starts_line = None + ex_arg.is_jump_target = instr.is_jump_target + instr.is_jump_target = False + + if instr.ex_arg_for is not None: + # instr is also an ex_arg for another instr + instr.ex_arg_for.first_ex_arg = ex_arg + ex_arg.ex_arg_for = instr.ex_arg_for + instr.ex_arg_for = None + else: + instr.first_ex_arg = ex_arg + ex_arg.ex_arg_for = instr + + for key, val in extend_args_record.items(): + bind_ex_arg_with_instr(val[0], key) + replace_instr(instructions, instr=key, new_instr=val) + + return modify_completed + + +def modify_vars(instructions, code_options): + co_names = code_options['co_names'] + co_varnames = code_options['co_varnames'] + co_freevars = code_options['co_freevars'] + for instrs in instructions: + if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': + assert ( + instrs.argval in co_varnames + ), f"`{instrs.argval}` not in {co_varnames}" + instrs.arg = co_varnames.index(instrs.argval) + elif instrs.opname == "LOAD_DEREF" or instrs.opname == "STORE_DEREF": + if sys.version_info >= (3, 11): + namemap = co_varnames + co_freevars + assert ( + instrs.argval in namemap + ), f"`{instrs.argval}` not in {namemap}" + instrs.arg = namemap.index(instrs.argval) + + +def calc_offset_from_bytecode_offset( + bytecode_offset: int, + instructions: list[dis.Instruction] | list[Instruction], +) -> int: + """ + Calculate the index from bytecode offset, because it have 2 bytes per instruction (for Python <= 3.10). + + Args: + bytecode_offset (int): The bytecode offset of the instruction. + + Returns: + int: The index of the instruction in the instruction list. + """ + + if sys.version_info >= (3, 11): + instruction_offsets = [x.offset for x in instructions] + return instruction_offsets.index(bytecode_offset) + return bytecode_offset // 2 + + +def replace_instr(instructions, instr, new_instr): + idx = instructions.index(instr) + instructions[idx : idx + 1] = new_instr + + +def instrs_info(instrs, mark=None, range=None): + ret = [] + start = -1 + end = 1000000 + if mark is not None and range is not None: + start = mark - range + end = mark + range + 1 + for idx, instr in enumerate(instrs): + if idx < start or idx >= end: + continue + if instr.starts_line is not None: + ret.append("") + ret.append( + "{line:<8s}{is_jump_target:>2s}{offset:>4d} {opname:<30s}{arg:<4s}{argval:<40s}{mark}".format( + line=str(instr.starts_line) if instr.starts_line else "", + is_jump_target=">>" if instr.is_jump_target else " ", + offset=instr.offset + if instr.offset or instr.offset == 0 + else -1, + opname=instr.opname, + arg=str(instr.arg) if instr.arg is not None else "", + argval=f"({instr.argval})" if instr.argval else "", + mark="", + ) + ) + if idx == mark: + ret[-1] = "\033[31m" + ret[-1] + "\033[0m" + return ret + + +def calc_stack_effect(instr: Instruction, *, jump: bool | None = None) -> int: + """ + Gets the stack effect of the given instruction. In Python 3.11, the stack effect of `CALL` is -1, + refer to https://github.com/python/cpython/blob/3.11/Python/compile.c#L1123-L1124. + + Args: + instr: The instruction. + + Returns: + The stack effect of the instruction. + + """ + if sys.version_info[:2] == (3, 11): + if instr.opname == "PRECALL": + return 0 + elif instr.opname == "CALL": + # NOTE(zrr1999): push_n = 1, pop_n = oparg + 2, stack_effect = push_n - pop_n = -oparg-1 + assert instr.arg is not None + return -instr.arg - 1 + return dis.stack_effect(instr.opcode, instr.arg, jump=jump) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py new file mode 100644 index 00000000000000..dcda7558e5a395 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +from enum import Enum + +from ...utils import InnerError, OrderedSet +from .instruction_utils import Instruction +from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP + + +@dataclasses.dataclass +class State: + reads: OrderedSet[str] + writes: OrderedSet[str] + visited: OrderedSet[int] + + +def is_read_opcode(opname): + if opname in [ + "LOAD_FAST", + "LOAD_DEREF", + "LOAD_NAME", + "LOAD_GLOBAL", + "LOAD_CLOSURE", + ]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + +def is_write_opcode(opname): + if opname in ["STORE_FAST", "STORE_NAME", "STORE_DEREF", "STORE_GLOBAL"]: + return True + if opname in ( + "DELETE_FAST", + "DELETE_DEREF", + "DELETE_NAME", + "DELETE_GLOBAL", + ): + return True + return False + + +def analysis_inputs( + instructions: list[Instruction], + current_instr_idx: int, + stop_instr_idx: int | None = None, +) -> OrderedSet[str]: + """ + Analyze the inputs of the instructions from current_instr_idx to stop_instr_idx. + + Args: + instructions (list[Instruction]): The instructions to analyze. + current_instr_idx (int): The index of the current instruction. + stop_instr_idx (int | None, optional): The index of the instruction to stop. Defaults to None. + If None, the analysis will stop at the end of the instructions. + + Returns: + set[str]: The analysis result. + """ + root_state = State(OrderedSet(), OrderedSet(), OrderedSet()) + + def fork( + state: State, start: int, jump: bool, jump_target: int + ) -> OrderedSet[str]: + new_start = start + 1 if not jump else jump_target + new_state = State( + OrderedSet(state.reads), + OrderedSet(state.writes), + OrderedSet(state.visited), + ) + return walk(new_state, new_start) + + def walk(state: State, start: int) -> OrderedSet[str]: + end = len(instructions) if stop_instr_idx is None else stop_instr_idx + for i in range(start, end): + if i in state.visited: + return state.reads + state.visited.add(i) + + instr = instructions[i] + if instr.opname in HAS_LOCAL | HAS_FREE: + if is_read_opcode(instr.opname) and instr.argval not in ( + state.writes + ): + state.reads.add(instr.argval) + elif is_write_opcode(instr.opname): + state.writes.add(instr.argval) + elif instr.opname in ALL_JUMP: + assert instr.jump_to is not None + target_idx = instructions.index(instr.jump_to) + # Fork to two branches, jump or not + jump_branch = fork(state, i, True, target_idx) + not_jump_branch = ( + fork(state, i, False, target_idx) + if instr.opname not in UNCONDITIONAL_JUMP + else OrderedSet() + ) + return jump_branch | not_jump_branch + elif instr.opname == "RETURN_VALUE": + return state.reads + return state.reads + + return walk(root_state, current_instr_idx) + + +@dataclasses.dataclass +class SpaceState: + reads: dict[str, Space] + writes: dict[str, Space] + visited: OrderedSet[int] + + def __or__(self, other): + reads = {} + reads.update(other.reads) + reads.update(self.reads) + writes = {} + writes.update(other.writes) + writes.update(self.writes) + return SpaceState(reads, writes, OrderedSet()) + + +class Space(Enum): + locals = 1 + globals = 2 + cells = 3 + all = 4 + + +def get_space(opname: str): + if "FAST" in opname: + return Space.locals + elif "GLOBAL" in opname: + return Space.globals + elif "DEREF" in opname or "CLOSURE" in opname: + return Space.cells + elif "NAME" in opname: + return Space.all + else: + raise InnerError(f"Unknown space for {opname}") + + +def analysis_used_names_with_space( + instructions: list[Instruction], + start_instr_idx: int, + stop_instr_idx: int | None = None, +): + root_state = SpaceState({}, {}, OrderedSet()) + + def fork( + state: SpaceState, start: int, jump: bool, jump_target: int + ) -> SpaceState: + new_start = start + 1 if not jump else jump_target + new_state = SpaceState( + dict(state.reads), + dict(state.writes), + OrderedSet(state.visited), + ) + return walk(new_state, new_start) + + def walk(state: SpaceState, start: int) -> SpaceState: + end = len(instructions) if stop_instr_idx is None else stop_instr_idx + for i in range(start, end): + if i in state.visited: + return state + state.visited.add(i) + + instr = instructions[i] + if instr.opname in HAS_LOCAL | HAS_FREE: + if is_read_opcode(instr.opname) and instr.argval not in ( + state.writes + ): + space = get_space(instr.opname) + state.reads[instr.argval] = space + elif is_write_opcode(instr.opname): + space = get_space(instr.opname) + state.writes[instr.argval] = space + elif instr.opname in ALL_JUMP: + assert instr.jump_to is not None + target_idx = instructions.index(instr.jump_to) + # Fork to two branches, jump or not + jump_branch = fork(state, i, True, target_idx) + not_jump_branch = ( + fork(state, i, False, target_idx) + if instr.opname not in UNCONDITIONAL_JUMP + else SpaceState({}, {}, OrderedSet()) + ) + return jump_branch | not_jump_branch + elif instr.opname == "RETURN_VALUE": + return state + return state + + state = walk(root_state, start_instr_idx) + all_used_vars = {} + all_used_vars.update(state.writes) + all_used_vars.update(state.reads) + return all_used_vars diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py new file mode 100644 index 00000000000000..cc63d5ecde967a --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from enum import Enum + +import opcode + +REL_JUMP = {opcode.opname[x] for x in opcode.hasjrel} +REL_BWD_JUMP = {opname for opname in REL_JUMP if "BACKWARD" in opname} +REL_FWD_JUMP = REL_JUMP - REL_BWD_JUMP +ABS_JUMP = {opcode.opname[x] for x in opcode.hasjabs} +HAS_LOCAL = {opcode.opname[x] for x in opcode.haslocal} +HAS_FREE = {opcode.opname[x] for x in opcode.hasfree} +ALL_JUMP = REL_JUMP | ABS_JUMP +UNCONDITIONAL_JUMP = {"JUMP_ABSOLUTE", "JUMP_FORWARD"} +if sys.version_info >= (3, 11): + UNCONDITIONAL_JUMP.add("JUMP_BACKWARD") + + +class JumpDirection(Enum): + FORWARD = "FORWARD" + BACKWARD = "BACKWARD" + + +class PopJumpCond(Enum): + FALSE = "FALSE" + TRUE = "TRUE" + NONE = "NONE" + NOT_NONE = "NOT_NONE" + + +# Cache for some opcodes, it's for Python 3.11+ +# https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53 +PYOPCODE_CACHE_SIZE = { + "BINARY_SUBSCR": 4, + "STORE_SUBSCR": 1, + "UNPACK_SEQUENCE": 1, + "STORE_ATTR": 4, + "LOAD_ATTR": 4, + "COMPARE_OP": 2, + "LOAD_GLOBAL": 5, + "BINARY_OP": 1, + "LOAD_METHOD": 10, + "PRECALL": 1, + "CALL": 4, +} diff --git a/python/paddle/jit/sot/opcode_translator/skip_files.py b/python/paddle/jit/sot/opcode_translator/skip_files.py new file mode 100644 index 00000000000000..5d5d04e56eca91 --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/skip_files.py @@ -0,0 +1,180 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import codecs +import collections +import contextlib +import copy +import copyreg +import dataclasses +import enum +import functools +import importlib +import inspect +import linecache +import logging +import multiprocessing +import operator +import os +import posixpath +import random +import re +import selectors +import signal +import sys +import tempfile +import threading +import tokenize +import traceback +import types +import typing +import unittest +import uuid +import warnings +import weakref + +import _collections_abc +import _weakrefset +import decorator +import google.protobuf +import numpy +import setuptools + +import paddle + +from ..utils import log + +NEED_SKIP_THIRD_PARTIY_MODULES = { + abc, + collections, + contextlib, + copy, + copyreg, + dataclasses, + enum, + functools, + google.protobuf, + importlib, + inspect, + linecache, + logging, + multiprocessing, + numpy, + operator, + os, + posixpath, + random, + re, + selectors, + signal, + tempfile, + threading, + tokenize, + traceback, + types, + typing, + unittest, + weakref, + _collections_abc, + _weakrefset, + decorator, + codecs, + uuid, + setuptools, + warnings, +} + +if sys.version_info < (3, 11): + import sre_compile + import sre_parse + + NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_compile) + NEED_SKIP_THIRD_PARTIY_MODULES.add(sre_parse) + +if sys.version_info < (3, 12): + import distutils + + NEED_SKIP_THIRD_PARTIY_MODULES.add(distutils) + + +def _strip_init_py(s): + return re.sub(r"__init__.py$", "", s) + + +def _module_dir(m: types.ModuleType): + return _strip_init_py(m.__file__) + + +skip_file_names = {_module_dir(m) for m in NEED_SKIP_THIRD_PARTIY_MODULES} + + +sot_path = os.path.dirname(__file__).rpartition(os.sep)[0] + os.sep +paddle_path = sys.modules["paddle"].__file__.rpartition(os.sep)[0] + os.sep + +skip_file_names.add(sot_path) +skip_file_names.add(paddle_path) +skip_file_names.add( + "<frozen importlib", +) +skip_file_names.add("<__array_function__ internals>") + +skip_file_name_re = re.compile( + f"^({'|'.join(map(re.escape, skip_file_names))})" +) + +customed_skip_code = set() + +no_skip_code = {paddle.nn.Sequential.forward.__code__} + + +def need_skip_path(filepath: str) -> bool: + """ + Check if the file should be skipped and not transcribed. + + Args: + filepath: The path of the file to check. + + Returns: + bool: True if the file should be skipped. + """ + if not filepath.startswith("<"): + filepath = os.path.abspath(filepath) + return bool(skip_file_name_re.match(filepath)) + + +def skip_function(function): + customed_skip_code.add(function.__code__) + return function + + +def need_skip(frame): + pycode = frame.f_code + if pycode in no_skip_code: + return False + if pycode in customed_skip_code: + log(3, f"Skip frame by code: {pycode}\n") + return True + filename = pycode.co_filename + if sys.version_info >= (3, 11) and filename.startswith("<frozen"): + # NOTE(SigureMo): In Python 3.11, the core modules essential for + # Python startup are “frozen”. So we need get original filename from + # frame. + # see https://docs.python.org/3/whatsnew/3.11.html#faster-startup for more details. + # This workaround is refer to pdb.py + # https://github.com/python/cpython/blob/3.11/Lib/pdb.py#L1328-L1331 + _filename = frame.f_globals.get('__file__', None) + if isinstance(_filename, str): + filename = _filename + return need_skip_path(filename) diff --git a/python/paddle/jit/sot/opcode_translator/transform.py b/python/paddle/jit/sot/opcode_translator/transform.py new file mode 100644 index 00000000000000..8fcf3cc5a2b72d --- /dev/null +++ b/python/paddle/jit/sot/opcode_translator/transform.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dis +import sys +from functools import partial + +from ..profiler import EventGuard +from ..utils import CodeStatus, log, log_do +from .custom_code import CustomCode +from .executor.executor_cache import OpcodeExecutorCache +from .skip_files import need_skip + + +def print_locals(frame): + local_key = [ + key for key in frame.f_locals.keys() if not key.startswith("__") + ] + print( + f"[eval_frame_callback] {frame.f_code.co_name} with locals {local_key}" + ) + print( + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} with cellvars + freevars: {frame.f_code.co_cellvars + frame.f_code.co_freevars}" + ) + + def convert_obj(obj): + import paddle + + if isinstance(obj, paddle.Tensor): + return "Tensor(" + str(obj.shape) + ")" + if isinstance(obj, list): + return [convert_obj(i) for i in obj] + return obj + + for key in local_key: + print( + f"[eval_frame_callback] {' ' * len(frame.f_code.co_name)} {key} = {convert_obj(frame.f_locals[key])}" + ) + + +def eval_frame_callback(frame, **kwargs) -> CustomCode: + with EventGuard( + f"eval_frame_callback: {frame.f_code.co_name}", event_level=2 + ): + # is generator + if frame.f_code.co_flags & 0x20 > 0: + return CustomCode(None, True) + + # NOTE(SigureMo): Temporary fallback when code has exception handling. + if sys.version_info >= (3, 11) and frame.f_code.co_exceptiontable: + log( + 3, + f"[eval_frame_callback] {frame.f_code} has co_exceptiontable\n", + ) + return CustomCode(None, False) + + if need_skip(frame): + log(3, f"[eval_frame_callback] skip {frame.f_code}\n") + custom_code = CustomCode(None, False) + new_code = frame.f_code + else: + log( + 2, f"[eval_frame_callback] start to translate: {frame.f_code}\n" + ) + log_do(4, partial(print_locals, frame)) + + log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n") + log_do(3, lambda: dis.dis(frame.f_code)) + + custom_code = OpcodeExecutorCache()(frame, **kwargs) + + if custom_code.code is None: + log( + 3, + "[transform] NewCode (same as origin code): " + + frame.f_code.co_name + + "\n", + ) + new_code = frame.f_code + else: + log( + 3, + "[transform] NewCode: " + custom_code.code.co_name + "\n", + ) + log_do(3, lambda: dis.dis(custom_code.code)) + new_code = custom_code.code + + # just check those codes which need open eval_frame + if ( + custom_code.disable_eval_frame is False + and CodeStatus().is_code_without_graph(new_code) + ): + log( + 3, + "[eval_frame_callback] Code has no graph, block it.\n", + ) + return CustomCode(None, True) + + return custom_code diff --git a/python/paddle/jit/sot/profiler.py b/python/paddle/jit/sot/profiler.py new file mode 100644 index 00000000000000..8315e03dd37f5c --- /dev/null +++ b/python/paddle/jit/sot/profiler.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager +from functools import wraps + +from paddle.framework import core + +_event_level = int(os.environ.get("EVENT_LEVEL", "-1")) + + +class SotProfiler: + def __enter__(self): + self.enable() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + def enable(self, tag=None): + core.nvprof_start() + core.nvprof_enable_record_event() + + def disable(self): + core.nvprof_stop() + + +@contextmanager +def EventGuard(event_name, event_level=0): + try: + global _event_level + need_pop = False + if _event_level >= event_level: + core.nvprof_nvtx_push(event_name) + need_pop = True + yield + finally: + if need_pop: + core.nvprof_nvtx_pop() + + +if _event_level == -1: + + @contextmanager + def _EmptyEventGuard(event_name, event_level=0): + yield + + EventGuard = _EmptyEventGuard # noqa: F811 + + +def event_register(event_name, event_level=0): + def event_wrapper(func): + @wraps(func) + def call_with_event(*args, **kwargs): + with EventGuard(event_name, event_level=0): + return func(*args, **kwargs) + + return call_with_event + + def do_nothing(func): + return func + + global _event_level + if _event_level >= event_level: + return event_wrapper + else: + return do_nothing diff --git a/python/paddle/jit/sot/psdb.py b/python/paddle/jit/sot/psdb.py new file mode 100644 index 00000000000000..38fa4d7479e160 --- /dev/null +++ b/python/paddle/jit/sot/psdb.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import builtins +import types +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from typing import TypeVar + + from typing_extensions import ParamSpec + + T = TypeVar("T") + P = ParamSpec("P") + +NO_BREAKGRAPH_CODES: set[types.CodeType] = set() +NO_FALLBACK_CODES: set[types.CodeType] = set() + + +def assert_true(input: bool): + assert input + + +def print(*args, **kwargs): + builtins.print("[Dygraph]", *args, **kwargs) + + +def breakpoint(): + import paddle + + old = paddle.framework.core.set_eval_frame(None) + builtins.breakpoint() + paddle.framework.core.set_eval_frame(old) + + +def check_no_breakgraph(fn: Callable[P, T]) -> Callable[P, T]: + NO_BREAKGRAPH_CODES.add(fn.__code__) + return fn + + +def breakgraph(): + pass + + +def check_no_fallback(fn: Callable[P, T]) -> Callable[P, T]: + NO_FALLBACK_CODES.add(fn.__code__) + return fn + + +def fallback(): + pass + + +def in_sot(): + return False diff --git a/python/paddle/jit/sot/symbolic/compile_cache.py b/python/paddle/jit/sot/symbolic/compile_cache.py new file mode 100644 index 00000000000000..8fa7444ff06841 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/compile_cache.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle + +from ..profiler import EventGuard +from ..utils import ( + Cache, + CodeStatus, + GraphLogger, + Singleton, + StepInfoManager, + log_do, +) +from .interpreter import compile_sir + +if TYPE_CHECKING: + from .symbolic_context import SymbolicTraceContext + + +def clear_eager_tensor_name(output_tensors): + for output_tensor in output_tensors: + output_tensor.name = "" + + +class FallbackWrapper: + """ + Used to store and call static graph methods generated by paddle.jit.to_static + """ + + def __init__(self, compiled_fn, SIR): + self.compiled_fn = compiled_fn + self.partial_program = None + self.concrete_program = None + self.SIR = SIR # for debug + + def __call__(self, *args, **kwargs): + with EventGuard(f"FallbackWrapper: {self.SIR.name}"): + if StepInfoManager().need_back_trace: + CodeStatus().trace_back_frames() + + log_do( + 2, + lambda: print("[FallbackWrapper] start run SIR: \n", self.SIR), + ) + log_do( + 4, + lambda: print( + self.compiled_fn.get_concrete_program(*args, **kwargs)[ + 1 + ].train_program + ), + ) + if self.partial_program is None: + with EventGuard("FallbackWrapper: call compiled_fn"): + outputs = self.compiled_fn(*args, **kwargs) + ( + self.concrete_program, + self.partial_program, + ) = self.compiled_fn.get_concrete_program(*args, **kwargs) + else: + # Speed up Resnet from 0.0068 --> 0.0057 + with EventGuard("FallbackWrapper: call partial_program"): + outputs = self.partial_program(*args, **kwargs) + + clear_eager_tensor_name(outputs) + log_do( + 1, + lambda: GraphLogger().add_subgraph( + self.concrete_program.main_program + ), + ) + log_do( + 4, + lambda: print("[CompileCache] run sir forward success."), + ) + return outputs + + +@Singleton +class CompileSIRCache(Cache): + """ + Cache the compiled function of SIR + """ + + def __init__(self): + super().__init__(weak=False) + + def key_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): + """ + generate a hash key for a SIR + + Args: + context: The context to compile + sir_name: The name of the sir to compile + build_strategy: The build strategy to compile + + Returns: + The hash key of the SIR + """ + sir = context.get_sir(sir_name) + # NOTE(dev): Is str(sir) a heavy opearation ? + hash_key = hash(str(sir)) + return hash_key + + def value_fn(self, context: SymbolicTraceContext, sir_name: str, **kwargs): + """ + Generate static graph function + + Args: + context: The context to compile + sir_name: The name of the sir to compile + build_strategy: The build strategy to compile + + Returns: + The static graph function + """ + build_strategy = kwargs.get("build_strategy", None) + backend = kwargs.get("backend", None) + return FallbackWrapper( + paddle.jit.to_static( + compile_sir(context, sir_name), + build_strategy=build_strategy, + backend=backend, + enable_fallback=False, + ), + context.get_sir(sir_name), + ) diff --git a/python/paddle/jit/sot/symbolic/interpreter.py b/python/paddle/jit/sot/symbolic/interpreter.py new file mode 100644 index 00000000000000..13265bbab4e380 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/interpreter.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import paddle +from paddle.utils import to_sequence + +from ..utils import InnerError, map_if, map_if_extend +from .statement_ir import SIRRuntimeCache, Symbol + +if TYPE_CHECKING: + from .statement_ir import Statement, StatementIR + from .symbolic_context import SymbolicTraceContext + + +def replace_symbol( + values: list[Symbol] | list[object], state: dict[str, Symbol] +): + """ + Replaces Symbol objects with their corresponding values. + + Args: + values: A list of values that may contain Symbol objects. + state: A dict mapping Symbol names to their corresponding values. + + Returns: + A new list with Symbol objects replaced by their corresponding values in the state dict. + """ + # deal with list / map etc. + values = map_if_extend( + values, + pred=lambda x: isinstance(x, Symbol), + true_fn=lambda x: state[x.name], + false_fn=lambda x: x, + ) + return values + + +def _append_opstack_between(start, end, stack): + # NOTE(xiongkun): we don't sync for speed. careful!! + # [start, end) + from paddle.framework import core + + op_maker = core.op_proto_and_checker_maker + callstack_attr_name = op_maker.kOpCreationCallstackAttrName() + for op in for_each_ops_between(start, end): + op._set_attr(callstack_attr_name, stack) + + +def for_each_ops_between(start, end): + # NOTE(xiongkun): we don't sync for speed. careful!! + # [start, end) + program = paddle.static.default_main_program() + ops = program.current_block().ops[start:end] + yield from ops + + +def opnum_in_program(): + # NOTE(xiongkun): we don't sync for speed. careful!! + program = paddle.static.default_main_program() + return len(program.current_block().ops) + + +class Interpreter: + """ + Interpreter is used to interpret and execute SIR. + """ + + def __init__(self, symbolic_context: SymbolicTraceContext): + self._context = symbolic_context + + def get_sir(self, name: str) -> StatementIR: + """ + Returns the StatementIR object by given name. + + Args: + name: The name of the StatementIR. + + Returns: + The StatementIR object with the given name. + """ + return self._context.get_sir(name) + + def run_sir(self, name: str, state: dict[str, Symbol]): + """ + Runs the StatementIR with the given name using the provided state. + + Args: + name: The name of the given StatementIR to run. + state: A dict mapping Symbol names to their corresponding values. + + Returns: + A list of the Symbol of the StatementIR after execution. + """ + SIR = self.get_sir(name) + for stmt in SIR.statements: + stmt: Statement + before_stmt_opnum = opnum_in_program() + inputs = replace_symbol(stmt.inputs, state) + outs = getattr(self, stmt.type)(stmt, inputs) + + def _set(v, s): + state[s.name] = v + + if len(to_sequence(outs)) != len(to_sequence(stmt.outputs)): + raise InnerError("Number output mismatch, some error happen.") + + _append_opstack_between( + before_stmt_opnum, opnum_in_program() + 1, stmt.stmt_stack + ) + + map_if( + outs, + stmt.outputs, + pred=lambda v, s: isinstance(s, Symbol), + true_fn=lambda v, s: _set(v, s), + false_fn=lambda v, s: None, + ) + # fetch outputs + return replace_symbol(SIR.outputs, state) + + def call(self, stmt: Statement, inputs): + SIR = self.get_sir(stmt.sir_name) + state = prepare_state(SIR, inputs) + return self.run_sir(stmt.sir_name, state) + + def api(self, stmt, inputs): + args, kwargs = inputs + return stmt.api(*args, **kwargs) + + def method(self, stmt, inputs): + args, kwargs = inputs + var = args[0] + return getattr(var, stmt.method)(*args[1:], **kwargs) + + def layer(self, stmt, inputs): + args, kwargs = inputs + layer = stmt.layer() + assert layer is not None, "SIR bound layer is None." + return layer(*args, **kwargs) + + +def compile_sir(context: SymbolicTraceContext, name: str): + """ + Compile a SIR to a new function + + Args: + context: The context to compile + name: The name of the sir to compile + + """ + + @paddle.jit.not_to_static + def wrapper(args): + """ + This function will be decorated by paddle.to_static. + so the args is variables, not eager tensors. + """ + interpreter = Interpreter(context) + SIR = interpreter.get_sir(name) + state = prepare_state(SIR, args) + return interpreter.run_sir(name, state) + + return wrapper + + +def prepare_state(SIR, inputs): + state = {} + + # update free vars if exsits + if SIRRuntimeCache().has_key(SIR.name): # noqa: W601 + free_var_seeker = SIRRuntimeCache().get_free_vars(SIR.name) + if free_var_seeker: + state = free_var_seeker() + + # bind inputs + for sir_inp, inp in zip(SIR.inputs, inputs): + state[sir_inp.name] = inp + + return state diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py new file mode 100644 index 00000000000000..11a08f36acd9d0 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/statement_ir.py @@ -0,0 +1,338 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +THIS FILE IS PRIVATE !! + +use interface in symbolic_context.py first. +""" +from __future__ import annotations + +import weakref +from typing import Any, Callable + +import paddle +from paddle.utils import is_sequence, map_structure + +from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend + + +class Symbol: + """ + Symbol is used to distinguish a string and a `math variable`. + """ + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return str(self) + + def __eq__(self, other): + if isinstance(other, str): + return self.name == other + return self.name == other.name + + def __hash__(self): + return hash(self.name) + + def __deepcopy__(self, memo=None): + return Symbol(self.name) + + +class Statement: + """ + Statement is used to represent a sentence of code for building the neural network model, + which has four types: "call", "api", "method", and "layer". + + Note: + Statement temporarily does not support control flow. + """ + + def __init__( + self, + type: str, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + assert type in ["call", "api", "method", "layer"] + self.name = name + self.inputs = inputs # (list of Symbols, dict of Symbols) + self.outputs = outputs # list of Symbol | PythonObj + self.stmt_stack = ( + stacks # a list of string to record the source code callstack. + ) + self.type = type + + def __str__(self): + def to_string(inps): + if isinstance(inps, str) or not is_sequence(inps): + return inps.__str__() + inps = (x.__str__() for x in inps) + return ", ".join(inps) + + return "{} || {} = {} ({}) ".format( + self.type + " " * (10 - len(self.type)), + to_string(self.outputs), + self.name, + to_string(self.inputs), + ) + + def __repr__(self): + return self.__str__() + + +class CallStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("call", name, inputs, outputs, stacks) + self.sir_name = name + + +class ApiStatement(Statement): + def __init__( + self, + api: Callable, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "api", "paddle." + api.__name__, inputs, outputs, stacks + ) + self.api = api + + +class MethodStatement(Statement): + def __init__( + self, + name: str, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__("method", name, inputs, outputs, stacks) + self.method = name + + +class LayerStatement(Statement): + def __init__( + self, + layer: paddle.nn.Layer, + inputs: list[Symbol], + outputs: list[Symbol], + stacks: list[str], + ): + super().__init__( + "layer", layer.__class__.__name__, inputs, outputs, stacks + ) + self.layer = weakref.ref(layer) + + +class StatementIR: + """ + StatementIR is the carrier that records the code for building the neural network model.It is + a representation of a purely computational structure, and does not care about specific values. + The function converted from StatementIR can ensure that it can be turned into a static state. + In this way, we can reuse the original `to_static` function to realize the execution of the static graph. + + Note: + Don't create by yourself, just use the StatementIRCache.get() + """ + + def __init__(self, name: str): + self.name = name + self.inputs = [] # list of Symbol | PythonObj + self.outputs = [] # list of Symbol | PythonObj + self.statements = [] # list of Statement + + def __len__(self): + return len(self.statements) + + def __deepcopy__(self, memo=None): + new_sir = StatementIR(self.name) + new_sir.inputs = list(self.inputs) + new_sir.outputs = list(self.outputs) + new_sir.statements = list(self.statements) + return new_sir + + def add_input(self, input): + self.inputs.append(input) + + def add_output(self, output): + self.outputs.append(output) + + def add_statement(self, statement): + assert isinstance(statement, Statement) + self.statements.append(statement) + + def analyse_inputs(self): + used_symbols = OrderedSet() + generated_symbols = OrderedSet() + for stmt in self.statements: + for inp in flatten_extend(stmt.inputs): + if isinstance(inp, Symbol) and inp not in generated_symbols: + used_symbols.add(inp) + for out in flatten_extend(stmt.outputs): + if isinstance(out, Symbol): + generated_symbols.add(out) + + input_symbols = sorted(used_symbols, key=lambda x: x.name) + return input_symbols + + def __str__(self): + strs = [] + strs.append("StatmentIR: %s" % self.name) + strs.append(f" inputs: {map_structure(lambda x: x.name, self.inputs)}") + strs.append( + f" outputs: {map_structure(lambda x: x.name, self.outputs)}" + ) + strs.append(" statements: ") + for stmt in self.statements: + strs.append(f" {stmt}") + return "\n".join(strs) + + def __repr__(self): + return self.__str__() + + def graph_size(self): + call_layers = [x for x in self.statements if x.type == "layer"] + return len(self.statements) + len(call_layers) + + +@Singleton +class StatementIRFactory: + """ + It is used to create a StatementIR. + """ + + def __init__(self): + self.cache = {} + self.name_generator = NameGenerator("SIR_") + + def __getitem__(self, key): + return self.cache[key] + + def create(self, input_name=None): + if input_name: + name = input_name + else: + name = self.name_generator.next() + + sir = StatementIR(name) + self.cache[name] = sir + return sir + + def update(self, stmt_ir): + name = stmt_ir.name + self.cache[name] = stmt_ir + + def clear(self): + want_clear = [ + key + for key in self.cache.keys() + if self.name_generator.match_name(key) + ] + for key in want_clear: + del self.cache[key] + + +@Singleton +class SIRRuntimeCache: + """ + It is used to cache the runtime information of the StatementIR. + """ + + def __init__(self): + self.cache = {} + # { name : (inputs, outputs, free_vars) } + # inputs : can be used when call_SIR, if free_vars exist + # outputs : used for generator new ProxyTensor output before fallback + # free_vars: (name, function) + + def __getitem__(self, key): + return self.cache[key] + + def has_key(self, key: str) -> bool: + """ + has_key is used to check whether the key is in the cache. + """ + return key in self.cache.keys() + + def set_origin_inputs(self, key: str, inputs: Any): + """ + Set Cache origin Inputs of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (inputs, val[1], val[2]) + else: + self.cache[key] = (inputs, None, None) + + def set_origin_outputs(self, key: str, outputs: Any): + """ + Set Cache origin outputs of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (val[0], outputs, val[2]) + else: + self.cache[key] = (None, outputs, None) + + def set_free_vars(self, key: str, free_vars: Any): + """ + Set Cache free variables of the StatementIR + """ + if key in self.cache.keys(): + val = self.cache[key] + self.cache[key] = (val[0], val[1], free_vars) + else: + self.cache[key] = (None, None, free_vars) + + def get_origin_inputs(self, key: str): + """ + Get the origin inputs of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][0] + else: + return None + + def get_origin_outputs(self, key: str): + """ + Get the origin outputs of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][1] + else: + return None + + def get_free_vars(self, key: str): + """ + Get the free variables of the StatementIR. + """ + if key in self.cache.keys(): + return self.cache[key][2] + else: + return None diff --git a/python/paddle/jit/sot/symbolic/symbolic_context.py b/python/paddle/jit/sot/symbolic/symbolic_context.py new file mode 100644 index 00000000000000..47f40bbcc9ec74 --- /dev/null +++ b/python/paddle/jit/sot/symbolic/symbolic_context.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from ..utils import log +from .compile_cache import CompileSIRCache +from .statement_ir import ( + ApiStatement, + CallStatement, + LayerStatement, + MethodStatement, + StatementIR, + StatementIRFactory, + Symbol, +) + + +class SymbolicTraceContext: + """ + SymbolicTraceContext is a context manager, which is used to record the symbolic trace. + + """ + + def __init__(self): + self.reset() + + def reset(self): + """ + Reset the context. + """ + + # TODO(dev): StatementIRFactory is a singleton, but SymbolicTraceContext is not. + # whether will two different SymbolicTraceContext objects be conflict ? + self.statement_factory = StatementIRFactory() + self.sir_stack = [self.statement_factory.create()] + + @property + def TOS(self): + """ + The top SIR of sir_stack. + + Returns: + StatementIR: the top of stack. + """ + + return self.sir_stack[-1] + + def call_SIR(self, sirname, inputs, outputs, stacks): + """ + Call a SIR, which is a subgraph. + """ + + stmt = CallStatement(sirname, inputs, outputs, stacks) + self.TOS.add_statement(stmt) + + def call_API(self, api, inputs, outputs, stacks): + """ + Call a paddle api. + """ + + assert callable(api), "call_API must receive a paddle api." + stmt = ApiStatement(api, inputs, outputs, stacks) + self.TOS.add_statement(stmt) + + def call_METHOD(self, method_name, inputs, outputs, stacks): + """ + Call a method of a api. The API here can be python or Paddle + """ + assert isinstance( + method_name, str + ), "call_METHOD must method api name. string." + assert isinstance( + inputs[0][0], Symbol + ), "call_METHOD must first augument must be Symbol Variable." + stmt = MethodStatement(method_name, inputs, outputs, stacks) + self.TOS.add_statement(stmt) + + def call_LAYER(self, layer, inputs, outputs, stacks): + """ + Call a layer of a api. + """ + stmt = LayerStatement(layer, inputs, outputs, stacks) + self.TOS.add_statement(stmt) + + def get_sir(self, name: str): + """ + Get a SIR from statement_factory. + + Args: + name (str): the name of SIR. + + Returns: + StatementIR: the SIR. + """ + return self.statement_factory[name] + + def reset_TOS(self): + """ + Reset the TOS. + """ + self.sir_stack.pop() + self.sir_stack.append(self.statement_factory.create()) + + def replace_TOS(self, sir): + """ + Use deepcopyed sir to replace the TOS. + This function will update statment_factory. + """ + self.sir_stack.pop() + self.sir_stack.append(sir) + self.statement_factory.update(sir) + + def compile_do_nothing(self, ret_vals): + """ + Return a dummy function, which will return an empty list. + + Args: + ret_vals (list[Symbol]): the return values of the function. + """ + + def dummy_func(*args, **kwargs): + return [] + + # return None function + dummy_stmt_ir = StatementIR("dummy_func") + dummy_stmt_ir.outputs = [] + dummy_stmt_ir.inputs = [] + return dummy_func, dummy_stmt_ir + + def compile_fn(self, ret_vals, **kwargs): + """ + start compile and return the python function, which must can be to_static without errors. + """ + cur_sir: StatementIR = self.TOS + # step0: if no statement, return a dummy function + if len(cur_sir.statements) == 0: + return self.compile_do_nothing(ret_vals) + # step1: analyse sir inputs and outputs + cur_sir.inputs = cur_sir.analyse_inputs() + # TODO: output analysis + cur_sir.outputs = ret_vals + log(2, "start subgraph compile and execution.\n") + log(2, self.TOS, "\n") + # step2: call compile_sir and get python function, third cache is triggered here. + static_func = CompileSIRCache()(self, cur_sir.name, **kwargs) + # step3: GC and reset TOS + # self.reset_TOS() + + return static_func, cur_sir diff --git a/python/paddle/jit/sot/translate.py b/python/paddle/jit/sot/translate.py new file mode 100644 index 00000000000000..88f569460a5ca0 --- /dev/null +++ b/python/paddle/jit/sot/translate.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, TypeVar + +import paddle + +from .opcode_translator import eval_frame_callback +from .utils import GraphLogger, StepInfoManager, StepState, log_do + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + + +def symbolic_translate(fn: Callable[P, R], **kwargs) -> Callable[P, R]: + """ + This function is the entry point of PaddleSOT. It sets eval_frame_callback before input + function to achieve Opcode-level translation. The translation process depends on the + simulation execution, in which information will be collected, especially the network + code. After the simulation execution is completed, the network code will be compiled + into a static graph Program to improve performance. + + Args: + fn: The input function. + + Returns: + Callable, The wrapped function. + + Examples: + >>> # doctest: +SKIP("Cound not get source code of function foo."") + >>> import paddle + >>> import numpy as np + >>> from sot.translate import symbolic_translate + >>> def foo(cond: paddle.Tensor, x: paddle.Tensor): + ... x += 1 + ... if cond: + ... x += 1 + ... else: + ... x -= 1 + ... return x + >>> symbolic_translate_foo = symbolic_translate(foo) + >>> # For the true branch, the output is 2. + >>> cond = paddle.to_tensor(True) + >>> x = paddle.to_tensor(0) + >>> dygraph_out = foo(cond, x) + >>> symbolic_translate_out = symbolic_translate_foo(cond, x) + >>> dygraph_out + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 2) + >>> symbolic_translate_out + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 2) + >>> np.testing.assert_allclose( + ... dygraph_out.numpy(), symbolic_translate_out.numpy() + ... ) + >>> # For the false branch, the output is 0. + >>> cond = paddle.to_tensor(False) + >>> dygraph_out = foo(cond, x) + >>> symbolic_translate_out = symbolic_translate_foo(cond, x) + >>> dygraph_out + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 0) + >>> symbolic_translate_out + Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, + 0) + >>> np.testing.assert_allclose( + ... dygraph_out.numpy(), symbolic_translate_out.numpy() + ... ) + + """ + + def callback(frame): + return eval_frame_callback(frame, **kwargs) + + def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: + assert hasattr( + fn, "__code__" + ), "Target function doesn't have code for simulating." + StepInfoManager().sot_step() + GraphLogger().clear() + paddle.framework.core.set_eval_frame(callback) + try: + outs = fn(*args, **kwargs) + except Exception as e: + raise e + finally: + paddle.framework.core.set_eval_frame(None) + + log_do(1, lambda: GraphLogger().print_info()) + return outs + + def impl_dynamic(*args: P.args, **kwargs: P.kwargs) -> R: + outs = fn(*args, **kwargs) + return outs + + def impl(*args: P.args, **kwargs: P.kwargs) -> R: + with StepInfoManager().step_guard(fn.__code__): + state = StepInfoManager().current_state + + if state == StepState.RUN_SOT: + return impl_sot(*args, **kwargs) + elif state == StepState.RUN_DYN: + return impl_dynamic(*args, **kwargs) + elif state == StepState.COLLECT_INFO: + return StepInfoManager().collect_info( + impl_dynamic, impl_sot, *args, **kwargs + ) + + return impl diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py new file mode 100644 index 00000000000000..a1f26ea622772b --- /dev/null +++ b/python/paddle/jit/sot/utils/__init__.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .code_status import CodeStatus # noqa: F401 +from .exceptions import ( # noqa: F401 + BreakGraphError, + FallbackError, + InnerError, + inner_error_default_handler, +) +from .magic_methods import magic_method_builtin_dispatch # noqa: F401 +from .paddle_api_config import ( # noqa: F401 + is_break_graph_tensor_methods, + is_inplace_api, + paddle_tensor_methods, +) +from .utils import ( # noqa: F401 + Cache, + GraphLogger, + NameGenerator, + OrderedSet, + ResumeFnNameFactory, + Singleton, + SotUndefinedVar, + StepInfoManager, + StepState, + cost_model, + count_if, + current_tmp_name_records, + execute_time, + flatten_extend, + get_unbound_method, + hashable, + in_paddle_module, + is_break_graph_api, + is_builtin_fn, + is_clean_code, + is_paddle_api, + is_strict_mode, + list_contain_by_id, + list_find_index_by_id, + log, + log_do, + map_if, + map_if_extend, + meta_str, + min_graph_size, + no_eval_frame, + show_trackers, + tmp_name_guard, +) diff --git a/python/paddle/jit/sot/utils/code_status.py b/python/paddle/jit/sot/utils/code_status.py new file mode 100644 index 00000000000000..007e77f6340041 --- /dev/null +++ b/python/paddle/jit/sot/utils/code_status.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from enum import Enum + +import paddle + +from .utils import Singleton, log + + +class CodeState(Enum): + UNKNOW = 1 + WITH_GRAPH = 2 + WITHOUT_GRAPH = 3 + + +class CodeInfo: + def __init__(self): + self.state = CodeState.UNKNOW + self.counter = 0 + + def __repr__(self): + return f"state: {self.state}, counter: {self.counter}" + + +@Singleton +class CodeStatus: + WITH_GRAPH_API = [ + paddle.nn.Layer.__call__.__code__, + paddle.nn.Layer._dygraph_call_func.__code__, + ] + + def __init__(self): + self.code_map = {} + self.setup_code_map() + + def setup_code_map(self): + for code in self.WITH_GRAPH_API: + info = CodeInfo() + info.state = CodeState.WITH_GRAPH + self.code_map[code] = info + + def clear(self): + self.code_map.clear() + self.setup_code_map() + + def is_code_without_graph(self, code): + if code not in self.code_map: + info = CodeInfo() + self.code_map[code] = info + else: + info = self.code_map[code] + + if info.state == CodeState.WITHOUT_GRAPH: + return True + if info.state == CodeState.UNKNOW: + info.counter += 1 + if info.counter >= 10: + log( + 3, + f"[CodeStatus] Switch state to WITHOUT_GRAPH for {code}\n", + ) + info.state = CodeState.WITHOUT_GRAPH + return False + + def trace_back_frames(self): + frame = inspect.currentframe() + while frame.f_back is not None: + frame = frame.f_back + code = frame.f_code + if code in self.code_map: + info = self.code_map[code] + if info.state != CodeState.WITH_GRAPH: + log( + 3, + f"[CodeStatus] Switch state to WITH_GRAPH for {code}\n", + ) + info.state = CodeState.WITH_GRAPH diff --git a/python/paddle/jit/sot/utils/exceptions.py b/python/paddle/jit/sot/utils/exceptions.py new file mode 100644 index 00000000000000..ff26f4ee2ba107 --- /dev/null +++ b/python/paddle/jit/sot/utils/exceptions.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import traceback + + +class SotErrorBase(Exception): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + from ..opcode_translator.breakpoint import BreakpointManager + + BreakpointManager().on_event(f"{self.__class__.__name__}") + + def print(self): + lines = traceback.format_tb(self.__traceback__) + print("".join(lines)) + + +class InnerError(SotErrorBase): + pass + + +class HasNoAttributeError(InnerError): + pass + + +class FallbackError(SotErrorBase): + def __init__(self, msg, disable_eval_frame=False): + super().__init__(msg) + self.disable_eval_frame = disable_eval_frame + + +# raise in inline function call strategy. +class BreakGraphError(SotErrorBase): + pass + + +def inner_error_default_handler(func, message_fn): + """Wrap function and an error handling function and throw an InnerError.""" + + def impl(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + message = message_fn(*args, **kwargs) + origin_exception_message = "\n".join( + traceback.format_exception(type(e), e, e.__traceback__) + ) + raise InnerError( + f"{message}.\nOrigin Exception is: \n {origin_exception_message}" + ) from e + + return impl diff --git a/python/paddle/jit/sot/utils/magic_methods.py b/python/paddle/jit/sot/utils/magic_methods.py new file mode 100644 index 00000000000000..56b20abdb05419 --- /dev/null +++ b/python/paddle/jit/sot/utils/magic_methods.py @@ -0,0 +1,130 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import operator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable + +from .utils import hashable + +if TYPE_CHECKING: + BinaryOp = Callable[[Any, Any], Any] + UnaryOp = Callable[[Any], Any] + + +INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[BinaryOp, tuple[str, BinaryOp]] = { + # inplace op fn: (magic name, non-inplace op fn) + operator.iadd: ("__iadd__", operator.add), + operator.iand: ("__iand__", operator.and_), + operator.iconcat: ("__iconcat__", operator.concat), + operator.ifloordiv: ("__ifloordiv__", operator.floordiv), + operator.ilshift: ("__ilshift__", operator.lshift), + operator.imatmul: ("__imatmul__", operator.matmul), + operator.imod: ("__imod__", operator.mod), + operator.imul: ("__imul__", operator.mul), + operator.ior: ("__ior__", operator.or_), + operator.ipow: ("__ipow__", operator.pow), + operator.irshift: ("__irshift__", operator.rshift), + operator.isub: ("__isub__", operator.sub), + operator.itruediv: ("__itruediv__", operator.truediv), + operator.ixor: ("__ixor__", operator.xor), +} + +NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES: dict[ + BinaryOp, tuple[str, str | None] +] = { + # op fn: (magic name, reverse magic name) + operator.add: ("__add__", "__radd__"), + operator.and_: ("__and__", "__rand__"), + operator.contains: ("__contains__", None), + operator.delitem: ("__delitem__", None), + operator.eq: ("__eq__", "__eq__"), + operator.floordiv: ("__floordiv__", "__rfloordiv__"), + operator.ge: ("__ge__", "__le__"), + operator.getitem: ("__getitem__", None), + operator.gt: ("__gt__", "__lt__"), + operator.le: ("__le__", "__ge__"), + operator.lshift: ("__lshift__", "__rlshift__"), + operator.lt: ("__lt__", "__gt__"), + operator.matmul: ("__matmul__", "__rmatmul__"), + operator.mod: ("__mod__", "__rmod__"), + operator.mul: ("__mul__", "__rmul__"), + operator.ne: ("__ne__", "__ne__"), + operator.or_: ("__or__", "__ror__"), + operator.pow: ("__pow__", "__rpow__"), + operator.rshift: ("__rshift__", "__rrshift__"), + operator.sub: ("__sub__", "__rsub__"), + operator.truediv: ("__truediv__", "__rtruediv__"), + operator.xor: ("__xor__", "__rxor__"), +} + +UNARY_OPS_TO_MAGIC_NAMES: dict[UnaryOp, str] = { + operator.neg: "__neg__", + operator.invert: "__invert__", + operator.pos: "__pos__", + operator.abs: "__abs__", + operator.index: "__index__", + operator.inv: "__inv__", + operator.invert: "__invert__", + operator.not_: "__not__", + operator.pos: "__pos__", + operator.truth: "__bool__", + bool: "__bool__", + abs: "__abs__", + float: "__float__", + len: "__len__", + int: "__int__", +} +# TODO(SigureMo): support any, all, sum + + +INPLACE_BINARY_OPS = set(INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys()) +NON_INPLACE_BINARY_OPS = set(NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES.keys()) +BINARY_OPS = INPLACE_BINARY_OPS | NON_INPLACE_BINARY_OPS +UNARY_OPS = set(UNARY_OPS_TO_MAGIC_NAMES.keys()) + + +@dataclass +class MagicMethod: + name: str + is_inplace: bool = False + is_reverse: bool = False + + +def magic_method_builtin_dispatch(fn: BinaryOp | UnaryOp) -> list[MagicMethod]: + if not hashable(fn): + return [] + if fn in INPLACE_BINARY_OPS: + inplace_magic_name, non_inplace_op = INPLACE_BINARY_OPS_TO_MAGIC_NAMES[ + fn + ] + return [ + MagicMethod(inplace_magic_name, is_inplace=True) + ] + magic_method_builtin_dispatch(non_inplace_op) + elif fn in NON_INPLACE_BINARY_OPS: + magic_name, reverse_magic_name = NON_INPLACE_BINARY_OPS_TO_MAGIC_NAMES[ + fn + ] + magic_methods = [MagicMethod(magic_name)] + if reverse_magic_name is not None: + magic_methods.append( + MagicMethod(reverse_magic_name, is_reverse=True) + ) + return magic_methods + elif fn in UNARY_OPS: + magic_name = UNARY_OPS_TO_MAGIC_NAMES[fn] + return [MagicMethod(magic_name)] + return [] diff --git a/python/paddle/jit/sot/utils/paddle_api_config.py b/python/paddle/jit/sot/utils/paddle_api_config.py new file mode 100644 index 00000000000000..06852d186a76c5 --- /dev/null +++ b/python/paddle/jit/sot/utils/paddle_api_config.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import paddle + + +def is_inplace_api(func): + inplace_apis = {paddle.static.setitem} + return func in inplace_apis + + +def get_tensor_methods(): + return [ + member_name + for member_name, member in inspect.getmembers(paddle.static.Variable) + if inspect.isfunction(member) + ] + + +def get_paddle_api(): + modules = [ + paddle, + paddle.nn.functional, + paddle.linalg, + paddle.signal, + paddle.fft, + paddle.vision.ops, + ] + special_paddle_apis = [paddle.tensor.fill_constant] + non_operator_related_apis = [ + paddle.in_dynamic_mode, + paddle.save, + paddle.load, + paddle.get_cuda_rng_state, + paddle.set_rng_state, + paddle.set_cuda_rng_state, + paddle.get_rng_state, + paddle.set_default_dtype, + paddle.check_shape, + paddle.summary, + paddle.finfo, + paddle.iinfo, + paddle.enable_static, + paddle.disable_static, + paddle.is_grad_enabled, + ] + # TODO: users should not call static_apis, but we need to use, so add static_apis here temporary + static_apis = [paddle.static.setitem, paddle.static.accuracy] + paddle_api_list = [] + for module in modules: + for fn_name in getattr(module, "__all__", []): + fn = getattr(module, fn_name) + if inspect.isfunction(fn): + paddle_api_list.append(fn) + return list( + set(special_paddle_apis) + | set(static_apis) + | set(paddle_api_list) - set(non_operator_related_apis) + ) + + +paddle_tensor_methods = get_tensor_methods() +paddle_api_list = get_paddle_api() + +# TODO(Aurelius84): It seems that we use it to judge 'in_paddle_module()'. +# Bug what does 'is_paddle_module' really means? Is all paddle.xx sub module +# considered as paddle module? +paddle_api_module_prefix = { + "paddle.nn.functional", + "paddle.nn.layer.activation", +} + +break_graph_set = set() + + +break_graph_tensor_method = { + 'register_hook', + 'numpy', + 'clear_gradient', + # TODO: Browse all possible functions and make prior judgments. +} + + +def is_break_graph_tensor_methods(method_name): + return method_name in break_graph_tensor_method + + +def add_break_graph_apis(apis: list): + break_graph_set.update(apis) diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py new file mode 100644 index 00000000000000..912ae7dec2692c --- /dev/null +++ b/python/paddle/jit/sot/utils/utils.py @@ -0,0 +1,730 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import builtins +import inspect +import os +import time +import types +import weakref +from collections import OrderedDict +from contextlib import contextmanager +from enum import Enum +from typing import Any, Generic, Iterable, Iterator, TypeVar +from weakref import WeakValueDictionary + +import numpy as np + +import paddle +from paddle.framework import Program +from paddle.utils import flatten, map_structure + +from .paddle_api_config import ( + break_graph_set, + paddle_api_list, + paddle_api_module_prefix, +) + +T = TypeVar("T") + + +def cost_model(): + return os.environ.get("COST_MODEL", "False") == "True" + + +def min_graph_size(): + return int(os.environ.get("MIN_GRAPH_SIZE", 10)) + + +class Singleton(Generic[T]): + def __init__(self, cls: type[T]): + self._cls = cls + self._instance = {} + + def __call__(self) -> T: + if self._cls not in self._instance: + self._instance[self._cls] = self._cls() + return self._instance[self._cls] + + +class NameGenerator: + def __init__(self, prefix): + self.counter = 0 + self.prefix = prefix + + def next(self): + name = self.prefix + str(self.counter) + self.counter += 1 + return name + + def match_name(self, name: str) -> bool: + return name.startswith(self.prefix) + + +_tmp_name_records = None + + +class TmpNameRecords: + def __init__(self): + self.name_generator = NameGenerator(prefix="_sot_tmp_") + self.tmp_names_record = OrderedDict() + + def next_name(self): + return self.name_generator.next() + + def add_tmp_var(self, expr): + if expr in self.tmp_names_record: + return self.tmp_names_record[expr] + else: + tmp_name = self.next_name() + self.tmp_names_record[expr] = tmp_name + return tmp_name + + +@contextmanager +def tmp_name_guard(): + global _tmp_name_records + old = _tmp_name_records + _tmp_name_records = TmpNameRecords() + yield + _tmp_name_records = old + + +def current_tmp_name_records(): + global _tmp_name_records + return _tmp_name_records + + +@Singleton +class ResumeFnNameFactory: + def __init__(self) -> None: + self.gen = NameGenerator('resume_') + + def next(self): + name = self.gen.next() + return name + + +def log(level, *args): + cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + if level <= cur_level: + print(*args, end="") + + +def log_do(level, fn): + cur_level = int(os.environ.get("SOT_LOG_LEVEL", "0")) + if level <= cur_level: + fn() + + +def no_eval_frame(func): + def no_eval_frame_func(*args, **kwargs): + old_cb = paddle.framework.core.set_eval_frame(None) + try: + retval = func(*args, **kwargs) + except: + raise + finally: + paddle.framework.core.set_eval_frame(old_cb) + return retval + + return no_eval_frame_func + + +def is_paddle_api(func): + if isinstance(func, paddle.nn.Layer): # ignore all the classes + return False + if hasattr(func, "__self__"): # ignore all the methods + return False + if inspect.isclass( + func + ): # paddle.Tensor should not be wrapped, but how about other situations? + return False + return in_paddle_module(func) or func in paddle_api_list + + +def is_builtin_fn(fn): + special_builtin_fns = [weakref.ref] + if fn in special_builtin_fns: + return True + if isinstance(fn, types.BuiltinFunctionType): + return True + for member_name, member in inspect.getmembers(builtins): + if member is fn and isinstance(member, type): + return True + return False + + +def in_paddle_module(func): + if hasattr(func, "__module__"): + module_str = func.__module__ + if module_str is None: + return False + log(5, "find paddle function with __module__: ", module_str, "\n") + if hasattr(func, "__name__"): + log( + 5, " with __name__ : ", func.__name__, "\n" + ) + log(5, " with results : ") + for prefix in paddle_api_module_prefix: + if module_str.startswith(prefix): + log(5, " True\n") + return True + log(5, " False\n") + return False + + +def is_break_graph_api(func): + return func in break_graph_set + + +def map_if(*structures, pred, true_fn, false_fn): + def replace(*args): + if pred(*args): + return true_fn(*args) + return false_fn(*args) + + return map_structure(replace, *structures) + + +def flatten_extend(structure): + for item in flatten(structure): + if isinstance(item, slice): + yield item.start + yield item.stop + yield item.step + else: + yield item + + +def map_if_extend(structure, pred, true_fn, false_fn): + """support extended structures like slice and SliceVariable""" + + def wrapped_pred(x): + if isinstance(x, slice): + return True + return pred(x) + + def wrapped_true_fn(x): + if isinstance(x, (slice)): + l = [x.start, x.stop, x.step] + l = map_if_extend(l, pred, true_fn, false_fn) + return slice(*l) + return true_fn(x) + + return map_if( + structure, pred=wrapped_pred, true_fn=wrapped_true_fn, false_fn=false_fn + ) + + +def count_if(*structures, pred): + def is_true(*args): + if pred(*args): + return 1 + return 0 + + return sum(flatten(map_structure(is_true, *structures))) + + +class Cache: + def __init__(self, weak=False): + if not weak: + self.cache = {} + else: + self.cache = WeakValueDictionary() + self.hit_num = 0 + + def __call__(self, *args, **kwargs): + cache_key = self.key_fn(*args, **kwargs) + if cache_key is None: + return self.value_fn(*args, **kwargs) + if cache_key in self.cache: + log(5, "cache hit: ", cache_key, "\n") + self.hit_num += 1 + return self.cache[cache_key] + value = self.value_fn(*args, **kwargs) + self.cache[cache_key] = value + return value + + def clear(self): + self.cache.clear() + self.hit_num = 0 + + def key_fn(self, *args, **kwargs): + raise NotImplementedError() + + def value_fn(self, *args, **kwargs): + raise NotImplementedError() + + +def execute_time(func): + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + execution_time = end_time - start_time + print("Execute time:", execution_time) + return result + + return wrapper + + +def meta_str(shape, dtype, stop_gradient): + return f"(shape: {shape}, dtype: {dtype}, stop_gradient: {stop_gradient})" + + +def is_strict_mode(): + return os.environ.get("STRICT_MODE", "0") == "1" + + +def show_trackers() -> str | None: + return os.environ.get("SHOW_TRACKERS", None) + + +def is_clean_code() -> bool: + return os.environ.get('CLEAN_CODE', "False") == "True" + + +def list_find_index_by_id(li: list[Any], item: Any) -> int: + return [id(it) for it in li].index(id(item)) + + +def list_contain_by_id(li: list[Any], item: Any) -> int: + return id(item) in [id(it) for it in li] + + +def get_unbound_method(obj, name): + # TODO(dev): Consider the case of patching methods to instances + return getattr(obj.__class__, name) + + +@Singleton +class GraphLogger: + graph_num: int + op_num: int + graphs: list[Program] + ops: list[paddle.base.framework.Operator] + + def __init__(self): + self.clear() + + def clear(self): + self.graph_num = 0 + self.op_num = 0 + self.graphs = [] + self.ops = [] + + def get_graph_num(self): + return self.graph_num + + def get_op_num(self): + return self.op_num + + def add_subgraph(self, program: Program): + self.graph_num += 1 + self.graphs.append(program) + + for block in program.blocks: + sub_op = [] + for op in block.ops: + self.op_num += 1 + sub_op.append(op) + self.ops.append(sub_op) + + def add_subgprah_info(self, strs): + for i in range(len(self.graphs)): + strs.append( + "------------------------------------------------------" + ) + + strs.append(f"subgraph {i}, OpNum: {len(self.ops[i])}") + strs.append(f"{self.graphs[i]}") + + def __str__(self): + strs = [] + strs.append("---------------- PaddleSOT graph info ----------------") + strs.append(f"SubgraphNum: {self.get_graph_num()}") + strs.append(f"OpNum: {self.get_op_num()}") + + # We can display every subgraph info + log_do(5, lambda: self.add_subgprah_info(strs)) + + strs.append("---------------- PaddleSOT graph info ----------------") + return "\n".join(strs) + + def __repr__(self): + return self.__str__() + + def print_info(self): + print(self) + + +@Singleton +class SotUndefinedVar: + pass + + +def hashable(obj): + try: + hash(obj) + return True + except TypeError as e: + return False + + +class OrderedSet(Generic[T]): + """ + A set that preserves the order of insertion. + """ + + _data: dict[T, None] + + def __init__(self, items: Iterable[T] | None = None): + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s + OrderedSet(1, 2, 3) + >>> s = OrderedSet() + >>> s + OrderedSet() + """ + self._data = dict.fromkeys(items) if items is not None else {} + + def __iter__(self) -> Iterator[T]: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> for item in s: + ... print(item) + 1 + 2 + 3 + """ + return iter(self._data) + + def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Union two sets. + + Args: + other: Another set to be unioned. + + Returns: + The union of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 | s2 + OrderedSet(1, 2, 3, 4) + """ + return OrderedSet(list(self) + list(other)) + + def __ior__(self, other: OrderedSet[T]): + """ + Union two sets in place. + + Args: + other: Another set to be unioned. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 |= s2 + >>> s1 + OrderedSet(1, 2, 3, 4) + """ + self._data.update(dict.fromkeys(other)) + return self + + def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Intersect two sets. + + Args: + other: Another set to be intersected. + + Returns: + The intersection of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 & s2 + OrderedSet(2, 3) + """ + return OrderedSet([item for item in self if item in other]) + + def __iand__(self, other: OrderedSet[T]): + """ + Intersect two sets in place. + + Args: + other: Another set to be intersected. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 &= s2 + >>> s1 + OrderedSet(2, 3) + """ + self._data = {item: None for item in self if item in other} + return self + + def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Subtract two sets. + + Args: + other: Another set to be subtracted. + + Returns: + The subtraction of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 - s2 + OrderedSet(1) + """ + return OrderedSet([item for item in self if item not in other]) + + def __isub__(self, other: OrderedSet[T]): + """ + Subtract two sets in place. + + Args: + other: Another set to be subtracted. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 -= s2 + >>> s1 + OrderedSet(1) + """ + self._data = {item: None for item in self if item not in other} + return self + + def add(self, item: T): + """ + Add an item to the set. + + Args: + item: The item to be added. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.add(4) + >>> s + OrderedSet(1, 2, 3, 4) + """ + self._data.setdefault(item) + + def remove(self, item: T): + """ + Remove an item from the set. + + Args: + item: The item to be removed. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.remove(2) + >>> s + OrderedSet(1, 3) + """ + del self._data[item] + + def __contains__(self, item: T) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> 1 in s + True + >>> 4 in s + False + """ + return item in self._data + + def __len__(self) -> int: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> len(s) + 3 + """ + return len(self._data) + + def __bool__(self) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> bool(s) + True + >>> s = OrderedSet() + >>> bool(s) + False + """ + return bool(self._data) + + def __eq__(self, other: object) -> bool: + """ + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([1, 2, 3]) + >>> s1 == s2 + True + >>> s3 = OrderedSet([3, 2, 1]) + >>> s1 == s3 + False + """ + if not isinstance(other, OrderedSet): + return NotImplemented + return list(self) == list(other) + + def __repr__(self) -> str: + data_repr = ", ".join(map(repr, self._data)) + return f"OrderedSet({data_repr})" + + +class StepState(Enum): + COLLECT_INFO = 1 + RUN_SOT = 2 + RUN_DYN = 3 + + +class StepInfo: + REQUIRED_DYN_INFOS = 10 + REQUIRED_SOT_INFOS = 10 + + USED_DYN_INFOS = 5 + + COLLECT_INFO_MAX_STEP = 50 + CV_BOUNDARY = 0.1 + + BACK_TRACE_STEPS = 20 + + def __init__(self): + self.step_count = -1 + self.state = ( + StepState.COLLECT_INFO if cost_model() else StepState.RUN_SOT + ) + self.dyn_time_costs = [] + self.avg_dyn_time = 0 + self.sot_time_costs = [] + self.sot_step = -1 + + def add_dynamic_time_info(self, time_cost): + self.dyn_time_costs.append(time_cost) + if len(self.dyn_time_costs) == self.REQUIRED_DYN_INFOS: + self.avg_dyn_time = np.mean( + self.dyn_time_costs[-self.USED_DYN_INFOS :] + ) + + def add_sot_time_info(self, time_cost, current_code): + self.sot_time_costs.append(time_cost) + if len(self.sot_time_costs) == self.REQUIRED_SOT_INFOS: + avg_sot_time = np.mean(self.sot_time_costs) + log( + 1, + f"[Cost Model] sot: {avg_sot_time}, dyn: {self.avg_dyn_time}\n", + ) + if avg_sot_time < self.avg_dyn_time: + log(1, f"[Cost Model] Switch to RUN_SOT: {current_code} \n") + self.state = StepState.RUN_SOT + elif ( + self.step_count > self.COLLECT_INFO_MAX_STEP + or np.std(self.sot_time_costs) / avg_sot_time < self.CV_BOUNDARY + ): + log(1, f"[Cost Model] Switch to RUN_DYN: {current_code}\n") + self.state = StepState.RUN_DYN + else: + log(1, f"[Cost Model] Decision delayed: {current_code}\n") + self.sot_time_costs.clear() + + def need_back_trace(self): + return self.step_count < self.BACK_TRACE_STEPS + + def need_dynamic_info(self): + return len(self.dyn_time_costs) < self.REQUIRED_DYN_INFOS + + +@Singleton +class StepInfoManager: + def __init__(self): + self.step_record = {} + self.current_code = None + self.current_step_info = None + + @contextmanager + def step_guard(self, code): + try: + old_code = self.current_code + old_info = self.current_step_info + + self.current_code = code + if code not in self.step_record: + self.step_record[code] = StepInfo() + self.current_step_info = self.step_record[code] + + self.current_step_info.step_count += 1 + + log( + 2, + f"[Cost Model] New step start, current state is {self.current_state}\n", + ) + yield + finally: + self.current_code = old_code + self.current_step_info = old_info + + def sot_step(self): + self.current_step_info.sot_step += 1 + + def collect_info(self, impl_dynamic, impl_sot, /, *args, **kwargs): + if self.current_step_info.need_dynamic_info(): + start_time = time.perf_counter() + outs = impl_dynamic(*args, **kwargs) + time_cost = time.perf_counter() - start_time + self.current_step_info.add_dynamic_time_info(time_cost) + else: + start_time = time.perf_counter() + outs = impl_sot(*args, **kwargs) + time_cost = time.perf_counter() - start_time + self.current_step_info.add_sot_time_info( + time_cost, self.current_code + ) + return outs + + @property + def need_back_trace(self): + return self.current_step_info.need_back_trace() + + @property + def current_step(self): + return self.current_step_info.step_count + + @property + def current_state(self): + return self.current_step_info.state + + def clear(self): + self.step_record.clear() + self.current_code = None + self.current_step = -1 diff --git a/python/paddle/jit/translated_layer.py b/python/paddle/jit/translated_layer.py index 766e72e0553e87..b5590e3194eef5 100644 --- a/python/paddle/jit/translated_layer.py +++ b/python/paddle/jit/translated_layer.py @@ -512,6 +512,11 @@ def _preprocess(self, program_desc): @switch_to_static_graph def _append_scale_to_output(self, program): + # 0. scale don't support bool output, we skip append scale for it + for out_desc in self._output_descs: + if out_desc.dtype() == core.VarDesc.VarType.BOOL: + return + # 1. append scale & save var scale_output_vars = [] with framework.program_guard(program): diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 7acafa290f7e0b..c74748793a4e9d 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1277,7 +1277,7 @@ def softplus(x, beta=1, threshold=20, name=None): \end{cases} Parameters: - x (Tensor): The input Tensor with data type float32, float64. + x (Tensor): The input Tensor with data type float32, float64, complex64, complex128. beta (float, optional): The value of :math:`\beta` for softplus. Default is 1 threshold (float, optional): The value of :math:`\varepsilon` for softplus. Default is 20 name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. @@ -1302,7 +1302,17 @@ def softplus(x, beta=1, threshold=20, name=None): return _C_ops.softplus(x, beta, threshold) else: check_variable_and_dtype( - x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'softplus' + x, + 'x', + [ + 'float16', + 'uint16', + 'float32', + 'float64', + 'complex64', + 'complex128', + ], + 'softplus', ) helper = LayerHelper('softplus', **locals()) out = helper.create_variable_for_type_inference(x.dtype) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 28341db5588aed..62050410b9c1a8 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1196,7 +1196,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): # get mask shape input_shape = x.shape - if not in_dynamic_or_pir_mode(): + if not in_dynamic_mode(): input_shape_tensor = paddle.shape(x) drop_axes = [axis] if isinstance(axis, int) else list(axis) if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1: @@ -1212,7 +1212,7 @@ def get_attrs(prog, dropout_prob, is_test, seed): ) ) mask_shape = [1] * len(input_shape) - if not in_dynamic_or_pir_mode(): + if not in_dynamic_mode(): for i in drop_axes: mask_shape[i] = input_shape_tensor[i] else: @@ -1658,10 +1658,16 @@ def pad(x, pad, mode='constant', value=0.0, data_format="NCHW", name=None): paddings = pad pad_value = value - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): out = _C_ops.pad(x, paddings, float(pad_value)) return out + if in_pir_mode(): + if isinstance(pad_value, paddle.pir.OpResult): + return _C_ops.pad(x, paddings, pad_value) + else: + return _C_ops.pad(x, paddings, float(pad_value)) + check_variable_and_dtype( x, 'x', diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 6caf0370366f4d..138146f376aeeb 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -13,7 +13,7 @@ # limitations under the License. from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode -from paddle.base.framework import _global_flags +from paddle.base.framework import _global_flags, in_dynamic_or_pir_mode from paddle.device import ( get_all_custom_device_type, is_compiled_with_cuda, @@ -126,7 +126,7 @@ def _conv_nd( name=None, ): # Due to the poor performance of NHWC, we transpose the input to NCHW. - if in_dynamic_mode() and op_type == "conv2d": + if in_dynamic_or_pir_mode() and op_type == "conv2d": pre_bias = _C_ops.conv2d( x, weight, @@ -155,7 +155,7 @@ def _conv_nd( else: return pre_bias - if in_dynamic_mode() and op_type == "depthwise_conv2d": + if in_dynamic_or_pir_mode() and op_type == "depthwise_conv2d": pre_bias = _C_ops.depthwise_conv2d( x, weight, @@ -174,7 +174,7 @@ def _conv_nd( else: return pre_bias - if in_dynamic_mode() and op_type == "conv3d": + if in_dynamic_or_pir_mode() and op_type == "conv3d": pre_bias = _C_ops.conv3d( x, weight, diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 6f111c61cb5071..c9fbf78a4fb4d2 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -15,7 +15,11 @@ import numpy as np from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode -from paddle.base.framework import Variable, in_dygraph_mode +from paddle.base.framework import ( + Variable, + in_dygraph_mode, + in_dynamic_or_pir_mode, +) from ...base.data_feeder import check_type, check_variable_and_dtype from ...base.layer_helper import LayerHelper @@ -372,7 +376,7 @@ def avg_pool2d( padding, 2, channel_last, ceil_mode=ceil_mode ) - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): output = _C_ops.pool2d( x, kernel_size, @@ -1254,7 +1258,7 @@ def max_pool2d( "When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d" ) - if in_dygraph_mode(): + if in_dynamic_or_pir_mode(): if return_mask: output = _C_ops.max_pool2d_with_index( x, kernel_size, stride, padding, False, False diff --git a/python/paddle/nn/initializer/normal.py b/python/paddle/nn/initializer/normal.py index c1bcb89f676f72..3a05bbed121f36 100644 --- a/python/paddle/nn/initializer/normal.py +++ b/python/paddle/nn/initializer/normal.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _C_ops +from paddle import _C_ops, pir from ...base import core, framework, unique_name from ...base.data_feeder import check_variable_and_dtype -from ...base.framework import _current_expected_place, in_dygraph_mode +from ...base.framework import ( + _current_expected_place, + in_dygraph_mode, + in_pir_mode, +) from .initializer import Initializer __all__ = [] @@ -54,7 +58,7 @@ def forward(self, var, block=None): """ block = self._check_block(block) - assert isinstance(block, framework.Block) + assert isinstance(block, (framework.Block, pir.Block)) check_variable_and_dtype( var, @@ -78,7 +82,17 @@ def forward(self, var, block=None): ) out_var._share_underline_tensor_to(var) return None - + elif in_pir_mode(): + place = _current_expected_place() + out_var = _C_ops.gaussian( + var.shape, + self._mean, + self._std_dev, + self._seed, + var.dtype, + place, + ) + return out_var else: op = block.append_op( type="gaussian_random", diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index b60a9e6818c8e4..8f962da6b6766c 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -164,7 +164,7 @@ def weight_only_linear( 'weight': [weight], 'weight_scale': [weight_scale], } - if bias: + if bias is not None: inputs["bias"] = [bias] attrs = {'weight_dtype': weight_dtype} diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py index 7cb628565cff95..8db65d61bb5bac 100644 --- a/python/paddle/nn/utils/transform_parameters.py +++ b/python/paddle/nn/utils/transform_parameters.py @@ -121,6 +121,7 @@ def parameters_to_vector(parameters, name=None): ) for i, param in enumerate(parameters): _inplace_reshape_dygraph(param, origin_shapes[i]) + out.stop_gradient = False return out diff --git a/python/paddle/optimizer/lr.py b/python/paddle/optimizer/lr.py index 0d24286fb40cdd..37a46f53707f11 100644 --- a/python/paddle/optimizer/lr.py +++ b/python/paddle/optimizer/lr.py @@ -45,6 +45,7 @@ 'MultiplicativeDecay', 'OneCycleLR', 'CyclicLR', + 'LinearLR', ] @@ -2229,6 +2230,125 @@ def get_lr(self): return lr +class LinearLR(LRScheduler): + r""" + Set the learning rate according to linear scheduler. + The learning rate will be firstly multiplied by start_factor and linearly increase to end learning rate. + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + total_steps (int): Number of iterations that the learning_rate reaches end learning_rate. + start_factor (float): Start learning rate is defined by `start_factor * learning_rate` . Default: 1./3. + end_factor (float) End learning rate is defined by `end_factor * learning_rate`. Default: 1.0. + last_epoch (int, optional): The index of last epoch. Can be set to restart training.Default: -1, means initial learning rate. + verbose: (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``LinearLR`` instance to schedule learning rate. + + Examples: + .. code-block:: python + :name: code-dynamic + + >>> # Example1: train on default dynamic graph mode + >>> import paddle + >>> import numpy as np + + >>> # train on default dynamic graph mode + >>> linear = paddle.nn.Linear(10, 10) + >>> scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5, total_steps=5, verbose=True) + >>> sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters()) + >>> for epoch in range(5): + ... for batch_id in range(20): + ... x = paddle.uniform([10, 10]) + ... out = linear(x) + ... loss = paddle.mean(out) + ... loss.backward() + ... sgd.step() + ... sgd.clear_gradients() + ... scheduler.step() + + .. code-block:: python + :name: code-static + + >>> # Example2: train on static graph mode + >>> import paddle + >>> import numpy as np + >>> paddle.enable_static() + >>> main_prog = paddle.static.Program() + >>> start_prog = paddle.static.Program() + >>> with paddle.static.program_guard(main_prog, start_prog): + ... x = paddle.static.data(name='x', shape=[None, 4, 5]) + ... y = paddle.static.data(name='y', shape=[None, 4, 5]) + ... z = paddle.static.nn.fc(x, 100) + ... loss = paddle.mean(z) + ... scheduler = paddle.optimizer.lr.LinearLR(learning_rate=0.5, + ... total_steps=5, verbose=True) + ... sgd = paddle.optimizer.SGD(learning_rate=scheduler) + ... sgd.minimize(loss) + ... + >>> exe = paddle.static.Executor() + >>> exe.run(start_prog) + >>> for epoch in range(5): + ... for batch_id in range(20): + ... out = exe.run( + ... main_prog, + ... feed={ + ... 'x': np.random.randn(3, 4, 5).astype('float32'), + ... 'y': np.random.randn(3, 4, 5).astype('float32') + ... }, + ... fetch_list=loss.name) + ... scheduler.step() + """ + + def __init__( + self, + learning_rate, + total_steps, + start_factor=1.0 / 3, + end_factor=1.0, + last_epoch=-1, + verbose=False, + ): + if start_factor > 1.0 or start_factor <= 0: + raise ValueError( + "`start_factor` must be greater than 0 and less or equal to 1, but got {}".format( + start_factor + ) + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "`end_factor` must be greater than 0 and less than 1, but got {}".format( + end_factor + ) + ) + + if total_steps <= 0: + raise ValueError( + f"`total_steps` must be greater than 0, but got {total_steps}" + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_steps = total_steps + + super().__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lr * self.start_factor + elif self.last_epoch > self.total_steps: + return self.last_lr + else: + base_lr = self.total_steps * self.start_factor + cur_factor = self.end_factor - self.start_factor + factor = 1.0 + cur_factor / ( + base_lr + (self.last_epoch - 1) * cur_factor + ) + return self.last_lr * factor + + def autoincreased_step_counter(counter_name=None, begin=1, step=1): """ :api_attr: Static Graph diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 6d6ecfb220c69d..f25a5bf9f771b3 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -312,11 +312,11 @@ def state_dict(self): Examples: .. code-block:: python - import paddle - emb = paddle.nn.Embedding(10, 10) + >>> import paddle + >>> emb = paddle.nn.Embedding(10, 10) - adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) - state_dict = adam.state_dict() + >>> adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) + >>> state_dict = adam.state_dict() ''' state_dict = {} @@ -1306,6 +1306,16 @@ def backward( parameter_list = parameters if parameters else self._parameter_list with paddle.static.program_guard(program, startup_program): if in_pir_mode(): + if parameter_list is None: + # all parameters will be updated. + program_all_params = ( + program.global_block().all_parameters() + ) + parameter_list = [ + param + for param in program_all_params + if param.stop_gradient is False + ] params_grads = [] grads = paddle.autograd.ir_backward.grad( loss, parameter_list, no_grad_vars=act_no_grad_set diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index 4772e8f83280da..145eb103918bf2 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -25,6 +25,7 @@ ) from paddle.base.libpaddle.pir import ( # noqa: F401 translate_to_new_ir, + translate_to_new_ir_with_param_map, set_global_program, set_insertion_point, reset_insertion_point_to_start, diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 63111c4256f270..ad7ecbc4a1cd24 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -12,14 +12,371 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import warnings + +from paddle.base.libpaddle import DataType + from . import OpResult _already_patch_opresult = False +_supported_int_dtype_ = [ + DataType.BOOL, + DataType.UINT8, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, +] + + +def create_tensor_with_batchsize(ref_var, value, dtype): + assert isinstance(ref_var, OpResult) + value = float(value) + batch_dim = -1 + out_shape = [] + for i, d in enumerate(ref_var.shape): + if d < 0: + if batch_dim < 0: + batch_dim = i + out_shape.append(d) + else: + out_shape.append(1) + else: + out_shape.append(d) + assert batch_dim != -1 + from paddle import _C_ops + from paddle.framework import core + + out = _C_ops.full_batch_size_like( + ref_var, out_shape, dtype, value, batch_dim, batch_dim, core.Place() + ) + out.stop_gradient = True + + return out + def monkey_patch_opresult(): + def safe_get_dtype(var): + try: + dtype = var.dtype + except: + raise ValueError("Cannot get data type from var") + return dtype + + def place(self): + """ + OpResult don't have 'place' interface in static graph mode + But this interface can greatly facilitate dy2static. + So we give a warnning here and return None. + """ + warnings.warn( + "OpResult do not have 'place' interface for pir graph mode, try not to use it. None will be returned." + ) + + @property + def _ndim(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.ndim) + 3 + """ + return len(self.shape) + + def ndimension(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.ndimension()) + 3 + """ + return len(self.shape) + + def dim(self): + """ + Returns the dimension of current OpResult + + Returns: + the dimension + + Examples: + .. code-block:: python + + >>> import paddle + + >>> paddle.enable_static() + + >>> # create a static OpResult + >>> x = paddle.static.data(name='x', shape=[3, 2, 1]) + >>> # print the dimension of the OpResult + >>> print(x.dim()) + 3 + """ + return len(self.shape) + + def _item(self): + """ + In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self. + It will check that the shape must be a 1-D tensor + """ + if len(self.shape) > 1: + raise TypeError( + f"Required input var should be 1-D OpResult, but received {self.shape}" + ) + return self + + def astype(self, dtype): + """ + **Notes**: + + Cast a OpResult to a specified data type. + + Args: + + self(OpResult): The source OpResult + + dtype: The target data type + + Returns: + OpResult: OpResult with new dtype + + Examples: + In Static Graph Mode: + + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + >>> startup_prog = paddle.static.Program() + >>> main_prog = paddle.static.Program() + >>> with paddle.static.program_guard(startup_prog, main_prog): + ... original_value = paddle.static.data(name = "new_value", shape=[2,2], dtype='float32') + ... new_value = original_value.astype('int64') + ... print("new value's dtype is: {}".format(new_value.dtype)) + ... + new OpResult's dtype is: paddle.int64 + + """ + from paddle import _C_ops + + if not isinstance(dtype, DataType): + dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype) + return _C_ops.cast(self, dtype) + + def _scalar_add_(var, value): + return paddle.scale(var, 1.0, value) + + def _scalar_sub_(var, value): + return paddle.scale(var, 1.0, -value) + + def _scalar_rsub_(var, value): + return paddle.scale(var, -1.0, value) + + def _scalar_mul_(var, value): + return paddle.scale(var, value, 0.0) + + def _scalar_div_(var, value): + return paddle.scale(var, 1.0 / value, 0.0) + + def _binary_creator_( + method_name, + python_api, + reverse=False, + scalar_method=None, + ): + def __impl__(self, other_var): + # 1. scalar exists cases + # we need combine the tensor.dtype and scalar.dtype, cast correct object + if isinstance(other_var, float): + # in all cases(+, -, *, /, **, //, %), we need cast tensor.dtype to float + if self.dtype in _supported_int_dtype_: + self = astype(self, DataType.FLOAT32) + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: + return scalar_method(self, other_var) + elif isinstance(other_var, int): + # in all cases(+, -, *, /, **, //, %), we can cast it to float + # because the output tensor.dtype depend on the type of input tensor + other_var = float(other_var) + # division is a special case + # NOTE(chenweihang): because we cast tensor to float32 instead float64, + # the division result can only guarantee the numerical accuracy of 6 digits + # after the decimal point. The result of numpy calculation is of float64 type, + # so the calculation result here and the calculation result of numpy are + # different after 6 decimal point. If necessary, we can also use float64 here. + # torch's behavior here is consistent with ours + if ( + python_api == paddle.divide + and self.dtype in _supported_int_dtype_ + ): + paddle.cast(self, DataType.FLOAT32) + # here use `scale` replace `elementwise` to get better performance + # but only +, -, *, / can use this method + if scalar_method is not None: + return scalar_method(self, other_var) + else: + # do nothing + pass + + # 2. create OpResult for scalar + lhs_dtype = safe_get_dtype(self) + other_var_opresult = other_var + if not isinstance(other_var, OpResult): + if reverse: + for elem in self.shape: + if elem < 0: + other_var_opresult = create_tensor_with_batchsize( + self, other_var, lhs_dtype + ) + + break + else: + # when break is not triggered, enter the else branch + other_var_opresult = paddle.fill_constant( + self.shape, + lhs_dtype, + other_var, + ) + else: + # add fill_op to current_block + other_var_opresult = paddle.fill_constant( + [], + lhs_dtype, + other_var, + ) + + # 3. unify right var type to left var + rhs_dtype = safe_get_dtype(other_var_opresult) + if lhs_dtype != rhs_dtype: + other_var_opresult = paddle.cast(other_var_opresult, lhs_dtype) + if reverse: + tmp = self + self = other_var_opresult + other_var_opresult = tmp + + if ( + python_api == paddle.divide + ) and self.dtype in _supported_int_dtype_: + self = paddle.cast(self, DataType.FLOAT32) + other_var = paddle.cast(other_var_opresult, DataType.FLOAT32) + + out = python_api(self, other_var_opresult) + return out + + __impl__.__doc__ = """ + Args: + self(OpResult): left hand OpResult + other_var(OpResult|float|int): right hand OpResult + + Returns: + OpResult + """ + __impl__.__name__ = method_name + return __impl__ + + import paddle + + opresult_methods = [ + ('place', place), + ('item', _item), + ('dim', dim), + ('ndimension', ndimension), + ('ndim', _ndim), + ('astype', astype), + ( + '__add__', + _binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_), + ), + # a+b == b+a. Do not need to reverse explicitly + ( + '__radd__', + _binary_creator_( + '__radd__', paddle.tensor.add, False, _scalar_add_ + ), + ), + ( + '__sub__', + _binary_creator_( + '__sub__', paddle.tensor.subtract, False, _scalar_sub_ + ), + ), + ( + '__rsub__', + _binary_creator_( + '__rsub__', paddle.tensor.subtract, True, _scalar_rsub_ + ), + ), + ( + '__mul__', + _binary_creator_( + '__mul__', paddle.tensor.multiply, False, _scalar_mul_ + ), + ), + # a*b == b*a. Do not need to reverse explicitly + ( + '__rmul__', + _binary_creator_( + '__rmul__', paddle.tensor.multiply, False, _scalar_mul_ + ), + ), + ( + '__div__', + _binary_creator_( + '__div__', paddle.tensor.divide, False, _scalar_div_ + ), + ), + ( + '__truediv__', + _binary_creator_( + '__truediv__', paddle.tensor.divide, False, _scalar_div_ + ), + ), + ( + '__rdiv__', + _binary_creator_('__rdiv__', paddle.tensor.divide, True, None), + ), + ( + '__rtruediv__', + _binary_creator_('__rtruediv__', paddle.tensor.divide, True, None), + ), + ] + global _already_patch_opresult if not _already_patch_opresult: + for method in opresult_methods: + method_name = method[0] + method_impl = method[1] + setattr(OpResult, method_name, method_impl) + # Handling Tensor Methods import paddle.tensor diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 76fa477ca8403c..a0f513061fd258 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -1224,6 +1224,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): with true_cond_block.block(): origin_true_output = true_fn() if origin_true_output is not None: + origin_true_output = map_structure( + create_undefined_var_in_subblock, origin_true_output + ) true_output = map_structure( copy_to_parent_func, origin_true_output ) @@ -1240,6 +1243,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): with false_cond_block.block(): origin_false_output = false_fn() if origin_false_output is not None: + origin_false_output = map_structure( + create_undefined_var_in_subblock, origin_false_output + ) false_output = map_structure( copy_to_parent_func, origin_false_output ) @@ -1356,6 +1362,18 @@ def merge_every_var_list(false_vars, true_vars, name): return merged_output +def create_undefined_var_in_subblock(var): + # to make sure the undefined var created in subblock. + from paddle.jit.dy2static.utils import ( + UndefinedVar, + create_undefined_variable_local, + ) + + if isinstance(var, UndefinedVar): + var = create_undefined_variable_local() + return var + + def copy_var_to_parent_block(var, layer_helper): if not isinstance(var, Variable): return var @@ -1711,7 +1729,7 @@ def Print( check_variable_and_dtype( input, 'input', - ['float32', 'float64', 'int32', 'int64', 'bool'], + ['uint16', 'float16', 'float32', 'float64', 'int32', 'int64', 'bool'], 'paddle.static.Print', ) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 61005132276d91..ce4cfc8ee883ba 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -329,6 +329,8 @@ from .math import polygamma_ # noqa: F401 from .math import renorm # noqa: F401 from .math import renorm_ # noqa: F401 +from .math import hypot # noqa: F401 +from .math import hypot_ # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -464,6 +466,8 @@ 'sum', 'nan_to_num', 'nan_to_num_', + 'hypot', + 'hypot_', 'nansum', 'nanmean', 'count_nonzero', diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 918b5f2c01e9cf..e5b589769d6272 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -36,6 +36,7 @@ from ..framework import ( LayerHelper, _current_expected_place, + _current_expected_place_, _get_paddle_place, convert_np_dtype_to_dtype_, core, @@ -651,10 +652,11 @@ def _handle_np_dtype(ndarray, dtype): def _to_tensor_static(data, dtype=None, stop_gradient=None): - if isinstance(data, Variable): + if isinstance(data, (Variable, paddle.pir.OpResult)): output = data if dtype is not None and dtype != data.dtype: output = paddle.cast(output, dtype) + else: if isinstance(data, np.number): # Special case for numpy scalars data = np.array(data) @@ -692,6 +694,9 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None): # fix numpy default dtype if data.dtype in ['float16', 'float32', 'float64']: data = data.astype(paddle.get_default_dtype()) + # Windows default type is 'int32', while Linux/Mac is 'int64'. Unify they. + elif data.dtype in ['int32']: + data = data.astype("int64") if dtype: target_dtype = dtype @@ -782,8 +787,7 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): """ place = _get_paddle_place(place) if place is None: - place = _current_expected_place() - + place = _current_expected_place_() if in_dynamic_mode(): return _to_tensor_non_static(data, dtype, place, stop_gradient) @@ -791,7 +795,6 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): else: re_exp = re.compile(r'[(](.+?)[)]', re.S) place_str = re.findall(re_exp, str(place))[0] - with paddle.static.device_guard(place_str): return _to_tensor_static(data, dtype, stop_gradient) @@ -884,24 +887,34 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): place = _current_expected_place() if force_cpu: place = core.CPUPlace() - if isinstance(shape, (list, tuple)): - shape = paddle.utils.convert_shape_to_list(shape) if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) + if in_dynamic_mode(): + value = float(value) + if isinstance(shape, (list, tuple)): + shape = paddle.utils.convert_shape_to_list(shape) + + else: + if isinstance(shape, (list, tuple)): + if paddle.utils._contain_var(shape): + shape = paddle.utils.get_pir_shape_tensor(shape, place) + elif isinstance(shape, paddle.pir.OpResult): + pass + else: + TypeError("Shape only supports OpReslut, or list, or tuple.") + if out is None: - value = float(value) if in_dynamic_mode() else value out = _C_ops.full(shape, value, dtype, place) out.stop_gradient = True return out if out is not None: - value = float(value) if in_dynamic_mode() else value - # final state mode is support out is not None. _C_ops.full_(out, shape, value, dtype, place) out.stop_gradient = True return out + else: attrs = {'force_cpu': force_cpu} dtype = convert_dtype(dtype) @@ -2264,22 +2277,16 @@ def convert_scalar(x): ) dtype = core.DataType.FLOAT32 - if dtype == core.VarDesc.VarType.BOOL or dtype == core.DataType.BOOL: + if dtype in [core.VarDesc.VarType.BOOL, core.DataType.BOOL]: value_name = "bool_values" values = [int(v) for v in input.flat] - elif ( - dtype == core.VarDesc.VarType.FP32 or dtype == core.DataType.FLOAT32 - ): + elif dtype in [core.VarDesc.VarType.FP32, core.DataType.FLOAT32]: value_name = "fp32_values" values = [float(v) for v in input.flat] - elif ( - dtype == core.VarDesc.VarType.INT32 or dtype == core.DataType.INT32 - ): + elif dtype in [core.VarDesc.VarType.INT32, core.DataType.INT32]: value_name = "int32_values" values = [int(v) for v in input.flat] - elif ( - dtype == core.VarDesc.VarType.INT64 or dtype == core.DataType.INT64 - ): + elif dtype in [core.VarDesc.VarType.INT64, core.DataType.INT64]: value_name = "int64_values" values = [int(v) for v in input.flat] else: diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 97172e39b5492a..71016a2208c154 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2621,7 +2621,7 @@ def eig(x, name=None): (-0.21026138961315155+0j)]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.eig(x) else: check_variable_and_dtype( @@ -2692,7 +2692,7 @@ def eigvals(x, name=None): ) ) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.eigvals(x) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 0deeefcc15c745..9b50993b891667 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -718,7 +718,7 @@ def greater_than(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, False, True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.greater_than(x, y) else: check_variable_and_dtype( @@ -807,7 +807,7 @@ def less_equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [True , True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.less_equal(x, y) else: check_variable_and_dtype( @@ -896,7 +896,7 @@ def less_than(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True , False]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.less_than(x, y) else: check_variable_and_dtype( @@ -985,7 +985,7 @@ def not_equal(x, y, name=None): Tensor(shape=[3], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True , True ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.not_equal(x, y) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 15d9eb5300a5af..4e100fd96a8e3c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2485,7 +2485,7 @@ def unique( else: axis = [axis] attr_dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out, indices, inverse, counts = _C_ops.unique( x, return_index, return_inverse, return_counts, axis, attr_dtype ) @@ -3170,25 +3170,21 @@ def tile(x, repeat_times, name=None): # Tensor(shape=[1, 6], dtype=int32, place=Place(gpu:0), stop_gradient=True, # [[1, 2, 3, 1, 2, 3]]) """ - if in_dynamic_or_pir_mode(): - if isinstance(repeat_times, core.eager.Tensor): - assert ( - repeat_times.ndim == 1 - ), "Only support ndim == 1 while repeat_times is a Tensor." - repeat_times = repeat_times.tolist() - return _C_ops.tile(x, repeat_times) - else: + def check_input(x, repeat_times): check_type( - repeat_times, 'repeat_times', (list, tuple, Variable), 'tile' + repeat_times, + 'repeat_times', + (list, tuple, Variable, paddle.pir.OpResult), + 'tile', ) - if isinstance(repeat_times, Variable): + if isinstance(repeat_times, (Variable, paddle.pir.OpResult)): assert ( - repeat_times.numel() == 1 - ), 'repeat_times must be a Tensor with one element.' + len(repeat_times.shape) == 1 + ), 'repeat_times must be a Tensor with ndim == 1.' else: for elem in repeat_times: - if isinstance(elem, Variable): + if isinstance(elem, (Variable, paddle.pir.OpResult)): assert ( elem.numel() == 1 ), 'Elements in repeat_times must be Tensor with one element or integers.' @@ -3219,15 +3215,29 @@ def tile(x, repeat_times, name=None): "some_var.stop_gradient == True supporting some_var is the input." ) - helper = LayerHelper('tile', **locals()) + if in_dynamic_mode(): + if isinstance(repeat_times, core.eager.Tensor): + assert ( + repeat_times.ndim == 1 + ), "Only support ndim == 1 while repeat_times is a Tensor." + repeat_times = repeat_times.tolist() - inputs = {"X": [x]} - attrs = {} + return _C_ops.tile(x, repeat_times) + elif in_pir_mode(): + check_input(x, repeat_times) + if isinstance(repeat_times, (list, tuple)): + if paddle.utils._contain_var(repeat_times): + repeat_times = paddle.utils._convert_to_tensor_list( + repeat_times + ) + return _C_ops.tile(x, repeat_times) + else: + check_input(x, repeat_times) def get_attr_repeat_times(list_repeat_times): attrs_repeat_times = [] for idx, times in enumerate(list_repeat_times): - if isinstance(times, Variable): + if isinstance(times, (Variable, paddle.pir.OpResult)): attrs_repeat_times.append(-1) else: attrs_repeat_times.append(times) @@ -3236,6 +3246,11 @@ def get_attr_repeat_times(list_repeat_times): ), "All elements in repeat_times must be positive for tile." return attrs_repeat_times + helper = LayerHelper('tile', **locals()) + + inputs = {"X": [x]} + attrs = {} + if isinstance(repeat_times, Variable): repeat_times.stop_gradient = True inputs['RepeatTimes'] = repeat_times diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 811f28c1ba97b3..75a0e714ba4a79 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -998,7 +998,7 @@ def remainder(x, y, name=None): [0, 3, 2, 1]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.remainder(x, y) else: return _elementwise_op(LayerHelper('elementwise_mod', **locals())) @@ -1288,7 +1288,7 @@ def minimum(x, y, name=None): Tensor(shape=[3], dtype=float64, place=Place(cpu), stop_gradient=True, [ 1. , -inf., 5. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.minimum(x, y) else: return _elementwise_op(LayerHelper('elementwise_min', **locals())) @@ -2937,7 +2937,7 @@ def min(x, axis=None, keepdim=False, name=None): [0., 0.]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.min(x, axis, keepdim) else: reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) @@ -4833,7 +4833,7 @@ def any(x, axis=None, keepdim=False, name=None): [True]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.any(x, axis, keepdim) else: reduce_all, axis = _get_reduce_axis(axis, x) @@ -6932,3 +6932,56 @@ def ldexp_(x, y, name=None): y = paddle.cast(y, dtype=out_dtype) two = paddle.to_tensor(2, dtype=out_dtype) return paddle.multiply_(x, paddle.pow(two, y)) + + +def hypot(x, y, name=None): + """ + Calculate the length of the hypotenuse of a right-angle triangle. The equation is: + + .. math:: + out = {\\sqrt{x^2 + y^2}} + + Args: + x (Tensor): The input Tensor, the data type is float32, float64, int32 or int64. + y (Tensor): The input Tensor, the data type is float32, float64, int32 or int64. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. And the data type is float32 or float64. + + Examples: + + .. code-block:: python + + >>> import paddle + + >>> x = paddle.to_tensor([3], dtype='float32') + >>> y = paddle.to_tensor([4], dtype='float32') + >>> res = paddle.hypot(x, y) + >>> print(res) + Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True, + [5.]) + + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(y, (paddle.Tensor, Variable)): + raise TypeError(f"y must be tensor type, but got {type(y)}") + + out = (paddle.pow(x, 2) + paddle.pow(y, 2)).sqrt() + return out + + +@inplace_apis_in_dygraph_only +def hypot_(x, y, name=None): + r""" + Inplace version of ``hypot`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_hypot`. + """ + if not isinstance(x, (paddle.Tensor, Variable)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(y, (paddle.Tensor, Variable)): + raise TypeError(f"y must be tensor type, but got {type(y)}") + + out = x.pow_(2).add_(y.pow(2)).sqrt_() + return out diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 54505f952bdfc0..beea2e7d904250 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -564,7 +564,7 @@ def cos(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.92106098, 0.98006660, 0.99500418, 0.95533651]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.cos(x) else: check_variable_and_dtype( @@ -754,7 +754,7 @@ def floor(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-1., -1., 0., 0.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.floor(x) else: check_variable_and_dtype( @@ -839,7 +839,7 @@ def round(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-1., -0., 1., 2.]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.round(x) else: check_variable_and_dtype( @@ -916,7 +916,7 @@ def sigmoid(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.40131235, 0.45016602, 0.52497917, 0.57444251]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sigmoid(x) else: check_variable_and_dtype( @@ -963,7 +963,7 @@ def sin(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.38941833, -0.19866933, 0.09983342, 0.29552022]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.sin(x) else: check_variable_and_dtype( @@ -1097,7 +1097,7 @@ def square(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.16000001, 0.04000000, 0.01000000, 0.09000000]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.square(x) else: check_variable_and_dtype( diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index f87e669cf198ef..feda36c2e85d49 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -796,6 +796,10 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): if in_dynamic_or_pir_mode(): shape = paddle.utils.convert_shape_to_list(shape) + if in_pir_mode() and paddle.utils._contain_var(shape): + shape = paddle.utils.get_pir_shape_tensor( + shape, _current_expected_place() + ) return _C_ops.uniform( shape, dtype, diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 5d9c9f6faa03be..c33bd0cd4f4159 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -171,7 +171,9 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): print(out4) # [[2, 2, 0, 1]] """ - if axis is not None and not isinstance(axis, (int, Variable)): + if axis is not None and not isinstance( + axis, (int, Variable, paddle.pir.OpResult) + ): raise TypeError( "The type of 'axis' must be int or Tensor or None in argmax, but received %s." % (type(axis)) @@ -188,7 +190,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None): flatten = True axis = 0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.argmax(x, axis, keepdim, flatten, var_dtype) else: helper = LayerHelper("argmax", **locals()) @@ -261,7 +263,9 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): print(out4) # [[1, 1, 1, 2]] """ - if axis is not None and not isinstance(axis, (int, Variable)): + if axis is not None and not isinstance( + axis, (int, Variable, paddle.pir.OpResult) + ): raise TypeError( "The type of 'axis' must be int or Tensor or None in argmin, but received %s." % (type(axis)) @@ -278,7 +282,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None): flatten = True axis = 0 - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.argmin(x, axis, keepdim, flatten, var_dtype) else: helper = LayerHelper("argmin", **locals()) diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 75057ed8accdbd..1c58242e877cec 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -37,6 +37,7 @@ from .layers_utils import padding_to_same_structure # noqa: F401 from .layers_utils import assert_same_structure # noqa: F401 from .layers_utils import get_shape_tensor_inputs # noqa: F401 +from .layers_utils import get_pir_shape_tensor # noqa: F401 from .layers_utils import convert_shape_to_list # noqa: F401 from .layers_utils import check_shape # noqa: F401 from .layers_utils import try_set_static_shape_tensor # noqa: F401 diff --git a/python/paddle/utils/layers_utils.py b/python/paddle/utils/layers_utils.py index d6de149dbd148b..88d19c37988749 100644 --- a/python/paddle/utils/layers_utils.py +++ b/python/paddle/utils/layers_utils.py @@ -21,7 +21,13 @@ import paddle from ..base.data_feeder import check_dtype, convert_dtype -from ..base.framework import Block, Variable, in_dygraph_mode +from ..base.framework import ( + Block, + Variable, + _current_expected_place, + core, + in_dygraph_mode, +) def convert_to_list(value, n, name, dtype=int): @@ -68,7 +74,9 @@ def convert_to_list(value, n, name, dtype=int): + str(value) ) for single_value in value_list: - assert not isinstance(single_value, Variable), ( + assert not isinstance( + single_value, (Variable, paddle.pir.OpResult) + ), ( "Required numerical type with '%s', but received Tensor." % dtype ) @@ -378,6 +386,22 @@ def _contain_var(list_or_tuple): return False +def get_pir_shape_tensor(list_shape, place=_current_expected_place()): + shape_tensor_list = [] + for dim in list_shape: + if isinstance(dim, paddle.pir.OpResult): + dim.stop_gradient = True + if convert_dtype(dim.dtype) != 'int32': + dim = paddle.cast(x=dim, dtype='int32') + if dim.shape == []: + dim = paddle.reshape(dim, [-1]) + shape_tensor_list.append(dim) + else: + temp_out = paddle.full([1], dim, core.DataType.INT32, place) + shape_tensor_list.append(temp_out) + return shape_tensor_list + + def get_shape_tensor_inputs(inputs, attrs, shape, op_type): from paddle.tensor import fill_constant diff --git a/python/setup.py.in b/python/setup.py.in index 4f2bce4bfbaada..10cbd7d54a86d0 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -430,6 +430,13 @@ packages=['paddle', 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.newir_dy2static', + 'paddle.jit.sot', + 'paddle.jit.sot.opcode_translator', + 'paddle.jit.sot.opcode_translator.executor', + 'paddle.jit.sot.opcode_translator.executor.variables', + 'paddle.jit.sot.opcode_translator.instruction_utils', + 'paddle.jit.sot.symbolic', + 'paddle.jit.sot.utils', 'paddle.inference', 'paddle.inference.contrib', 'paddle.inference.contrib.utils', diff --git a/setup.py b/setup.py index 221e0a0770e062..e12d676cb8a5f2 100644 --- a/setup.py +++ b/setup.py @@ -1425,6 +1425,13 @@ def get_setup_parameters(): 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.newir_dy2static', + 'paddle.jit.sot', + 'paddle.jit.sot.opcode_translator', + 'paddle.jit.sot.opcode_translator.executor', + 'paddle.jit.sot.opcode_translator.executor.variables', + 'paddle.jit.sot.opcode_translator.instruction_utils', + 'paddle.jit.sot.symbolic', + 'paddle.jit.sot.utils', 'paddle.inference', 'paddle.inference.contrib', 'paddle.inference.contrib.utils', diff --git a/test/amp/test_amp_decorate.py b/test/amp/test_amp_decorate.py index f956d37c63b39c..13b3b7fdd4d0f6 100644 --- a/test/amp/test_amp_decorate.py +++ b/test/amp/test_amp_decorate.py @@ -78,6 +78,44 @@ def forward(self, inputs): return x +class LayerNorm2D(paddle.nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(args, *kwargs) + + def forward(self, x): + x = x.transpose([0, 2, 3, 1]) + x = super().forward(x) + return x.transpose([0, 3, 1, 2]) + + +class CustomLayer(paddle.nn.Layer): + def __init__( + self, input_channel, hidden_size, fp16_conv=True, fp16_linear=True + ): + super().__init__() + self.conv = ConvBNLayer(input_channel, 8, 3) + self.linear = paddle.nn.Linear(8, hidden_size) + self.layernorm = paddle.nn.Sequential( + LayerNorm2D(hidden_size), + LayerNorm2D(hidden_size), + ) + self.fp16_conv = fp16_conv + self.fp16_linear = fp16_linear + + def forward(self, inputs): + with paddle.amp.auto_cast(enable=self.fp16_conv): + if not self.fp16_conv: + inputs = inputs.astype('float32') + x = self.conv(inputs) + with paddle.amp.auto_cast(enable=self.fp16_linear): + if not self.fp16_linear: + x = x.astype('float32') + x = self.linear(x) + x = F.relu(x) + x = self.layernorm(x) + return x + + @unittest.skipIf( not core.is_compiled_with_cuda() or paddle.device.cuda.get_device_capability()[0] < 7.0, @@ -167,6 +205,22 @@ def test_excluded_layers_attr_none(self): fp16_layers=[model.conv._conv, model.linear], ) + def test_excluded_layers_custom_layer(self): + if not paddle.amp.is_float16_supported(): + return + model = CustomLayer(4, 8) + model = paddle.amp.decorate( + models=model, + level='O2', + dtype='bfloat16', + excluded_layers=[paddle.nn.LayerNorm, paddle.nn.BatchNorm], + ) + with paddle.amp.auto_cast(level='O2'): + out = model(paddle.rand(shape=[2, 4, 8, 8], dtype='float32')) + self.check_results( + fp32_layers=[model.layernorm, model.conv._batch_norm], + ) + if __name__ == '__main__': unittest.main() diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 52ba882bc3e2a6..8700ab2e070744 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -119,6 +119,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_semi_auto_parallel_single_strategy) set_tests_properties(test_semi_auto_parallel_single_strategy PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_semi_auto_parallel_hybrid_strategy MODULES + test_semi_auto_parallel_hybrid_strategy) + set_tests_properties(test_semi_auto_parallel_hybrid_strategy + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir) set_tests_properties(test_gpt_with_newir PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) diff --git a/test/auto_parallel/semi_auto_parallel_for_elementwise.py b/test/auto_parallel/semi_auto_parallel_for_elementwise.py index b7e3e30b89e562..24bf0c8be9e88b 100644 --- a/test/auto_parallel/semi_auto_parallel_for_elementwise.py +++ b/test/auto_parallel/semi_auto_parallel_for_elementwise.py @@ -18,6 +18,7 @@ import paddle import paddle.distributed as dist +import paddle.nn.functional as F class TestElementwiseApiForSemiAutoParallel: @@ -27,17 +28,34 @@ def __init__(self): self._seed = eval(os.getenv("seed")) self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + paddle.seed(self._seed) + np.random.seed(self._seed) + def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + def test_unary_body(self, x_shape, out_shape, x_specs, unary_func): + x = paddle.randn(x_shape, self._dtype) + x.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + + dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) + dist_x.stop_gradient = False + + dist_out = unary_func(dist_x) + out = unary_func(x) + self.check_tensor_eq(out, dist_out) + + dist_out.backward() + out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + def test_binary_body( self, x_shape, y_shape, out_shape, x_specs, y_specs, binary_func ): - paddle.seed(self._seed) - np.random.seed(self._seed) - x = paddle.randn(x_shape, self._dtype) y = paddle.randn(y_shape, self._dtype) x.stop_gradient = False @@ -129,6 +147,22 @@ def test_sub_x_y_shard_broadcast(self): binary_func=paddle.subtract, ) + def test_square_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=paddle.square, + ) + + def test_relu_x_shard(self): + self.test_unary_body( + x_shape=[4, 16], + out_shape=[4, 16], + x_specs=['x', None], + unary_func=F.relu, + ) + def run_test_case(self): if self._backend == "cpu": paddle.set_device("cpu") @@ -141,6 +175,10 @@ def run_test_case(self): self.test_add_x_shard_broadcast() self.test_add_x_y_shard() self.test_add_x_y_shard_broadcast() + self.test_sub_x_shard() + self.test_sub_x_y_shard_broadcast() + self.test_square_x_shard() + self.test_relu_x_shard() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py index b83d9ffb87e7e8..3ca9baac5b5082 100644 --- a/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py +++ b/test/auto_parallel/semi_auto_parallel_for_replicated_spmd.py @@ -18,7 +18,6 @@ import paddle import paddle.distributed as dist -import paddle.nn.functional as F class TestReplicatedSPmdApiForSemiAutoParallel: @@ -50,28 +49,25 @@ def create_local_and_dist_tensor_pair(self, np_array, sharding_specs): return local_t, dist_t # input: phi::Tensor - # output: phi::Tensor - def test_relu(self): - x = np.random.random(size=[4, 4]).astype(self._dtype) + # output: std::vector<phi::Tensor> + def test_unbind(self): + x = np.random.random(size=[2, 8]).astype("float32") local_in, dist_in = self.create_local_and_dist_tensor_pair( x, ['x', None] ) - local_out = F.relu(local_in) - dist_out = F.relu(dist_in) - np.testing.assert_equal( - dist_out.dist_attr.dims_mapping, [-1, -1], verbose=True - ) - self.check_tensor_eq(local_out, dist_out) + local_out1, local_out2 = paddle.unbind(local_in, axis=0) + dist_out1, dist_out2 = paddle.unbind(dist_in, axis=0) + self.check_tensor_eq(local_out1, dist_out1) + self.check_tensor_eq(local_out2, dist_out2) + + local_out = paddle.add(local_out1, local_out2) + dist_out = paddle.add(dist_out1, dist_out2) - # test backward local_out.backward() dist_out.backward() - np.testing.assert_equal(dist_in.grad._local_shape, [2, 4], verbose=True) - np.testing.assert_equal( - dist_in.grad.dist_attr.dims_mapping, [0, -1], verbose=True - ) self.check_tensor_eq(local_in.grad, dist_in.grad) + # mutiple operators def test_mse_loss(self): x = np.random.random(size=[4, 4]).astype(self._dtype) y = np.random.random(size=[4]).astype(self._dtype) @@ -104,8 +100,8 @@ def run_test_case(self): else: raise ValueError("Only support cpu or gpu backend.") - self.test_relu() self.test_mse_loss() + self.test_unbind() if __name__ == '__main__': diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py new file mode 100644 index 00000000000000..90532a647812ad --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_simple_net_hybrid.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from semi_auto_parallel_simple_net import ( + CLASS_NUM, + IMAGE_SIZE, + TestSimpleNetForSemiAutoParallel, +) + +import paddle +import paddle.distributed as dist +from paddle import nn + + +class DPAndMPDemoNet(nn.Layer): + def __init__(self, np_w0, np_w1, mesh): + super().__init__() + self.mesh = mesh + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="dmp_demo_weight_1", + initializer=paddle.nn.initializer.Assign(np_w0), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'y']), + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="dmp_nemo_weight_2", + initializer=paddle.nn.initializer.Assign(np_w1), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['y', None]), + ) + + def forward(self, x): + y = paddle.matmul( + dist.shard_tensor( + x, + dist_attr=dist.DistAttr( + mesh=self.mesh, sharding_specs=['x', None] + ), + ), + self.w0, + ) + z = paddle.matmul(y, self.w1) + return z + + +class TestSimpleNetHybridStrategyForSemiAutoParallel( + TestSimpleNetForSemiAutoParallel +): + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) + + paddle.set_device(self._backend) + + self.init_input_data() + self.init_single_card_net_result() + + def test_dp_mp_demo_net(self): + ( + self.dp_mp_loss, + self.dp_mp_w0_grad, + self.dp_mp_w1_grad, + ) = self.run_dynamic(DPAndMPDemoNet(self.w0, self.w1, self._mesh)) + self.check_tensor_eq(self.dp_mp_loss, self.base_loss) + self.check_tensor_eq(self.dp_mp_w0_grad, self.base_w0_grad) + self.check_tensor_eq(self.dp_mp_w1_grad, self.base_w1_grad) + + def run_test_case(self): + self.test_dp_mp_demo_net() + + +if __name__ == '__main__': + TestSimpleNetHybridStrategyForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index cf034e33678aa1..c1f7e895e0486f 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -18,6 +18,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_default_data_parallel_rule MODULES test_default_data_parallel_rule) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) + py_test_modules(test_flatten_rule MODULES test_flatten_rule) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/spmd_rules/test_flatten_rule.py b/test/auto_parallel/spmd_rules/test_flatten_rule.py new file mode 100644 index 00000000000000..599b2ddf4bf958 --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_flatten_rule.py @@ -0,0 +1,398 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from collections import OrderedDict + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestFlattenSPMDRule(unittest.TestCase): + def setUp(self): + self.rule = core.get_phi_spmd_rule("flatten") + + x_shape = [8, 16, 8, 24] + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + x_tensor_dist_attr = TensorDistAttr() + x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1] + x_tensor_dist_attr.process_mesh = process_mesh + self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr) + self.attrs = OrderedDict() + + def test_flatten_infer_forward(self): + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, 1] [ 0, -1, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, 1] [ -1, 0, 1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] + # dims_mapping: [-1, -1, 1, 0] --> [-1, -1, -1, 0] [ -1, -1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 1, 0]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, 0] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [-1, 0, 1, -1] --> [-1, -1, -1, -1] [ -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [0, -1, -1, 1] --> [0, -1, -1, -1] [ 0] + self.x_dist_tensor_spec.set_dims_mapping([0, -1, -1, 1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] + # dims_mapping: [1, 0, -1, -1] --> [1, -1, -1, -1] [ 1] + self.x_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [-1, -1, 0, 1] --> [-1, -1, -1, -1] [-1, -1] + self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [-1, 0, -1, 1] --> [-1, 0, -1, -1] [-1, 0] + self.x_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] + # dims_mapping: [0, 1, -1, -1] --> [0, 1, -1, -1] [0, 1] + self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_forward( + self.x_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + def test_flatten_infer_backward(self): + process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + + output_tensor_dist_attr = TensorDistAttr() + output_tensor_dist_attr.dims_mapping = [-1, -1, -1] + output_tensor_dist_attr.process_mesh = process_mesh + self.output_dist_tensor_spec = DistTensorSpec( + [8, 16 * 8, 24], output_tensor_dist_attr + ) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [0, -1, 1] --> [0, -1, -1, 1], [0, -1, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([0, -1, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual(len(infered_input_dist_attrs), 1) + self.assertEqual(len(infered_output_dist_attrs), 1) + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1, 1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [0, 1, -1] --> [0, 1, -1, -1], [0, 1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([0, 1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8, 24] (input --> output) + # dims_mapping: [-1, 0, 1] --> [-1, 0, -1, 1], [-1, 0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8, 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = 2 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, 0, -1, 1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [-1] --> [-1, -1, -1, -1], [-1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([-1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [0] --> [0, -1, -1, -1], [0] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0]) + + # shape: [8, 16, 8, 24] --> [8 * 16 * 8 * 24] (input --> output) + # dims_mapping: [1] --> [1, -1, -1, -1], [1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8 * 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([1]) + self.attrs['start_axis'] = 0 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [-1, -1] --> [-1, -1, -1, -1], [-1, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([-1, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [-1, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [0, -1] --> [0, -1, -1, -1], [0, -1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0, -1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, -1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, -1]) + + # shape: [8, 16, 8, 24] --> [8, 16 * 8 * 24] (input --> output) + # dims_mapping: [0, 1] --> [0, 1, -1, -1], [0, 1] (output --> input, output) + self.output_dist_tensor_spec.shape = [8, 16 * 8 * 24] + self.output_dist_tensor_spec.set_dims_mapping([0, 1]) + self.attrs['start_axis'] = 1 + self.attrs['stop_axis'] = -1 + result_dist_attrs = self.rule.infer_backward( + self.x_dist_tensor_spec, + self.output_dist_tensor_spec, + self.attrs['start_axis'], + self.attrs['stop_axis'], + ) + infered_input_dist_attrs = result_dist_attrs[0] + infered_output_dist_attrs = result_dist_attrs[1] + + self.assertEqual( + infered_input_dist_attrs[0].dims_mapping, [0, 1, -1, -1] + ) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [0, 1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py index 8880aac9d261ba..dbeec8ec362220 100644 --- a/test/auto_parallel/test_api_dist_branch.py +++ b/test/auto_parallel/test_api_dist_branch.py @@ -136,23 +136,6 @@ def test_broadcast_tensors_for_dist_tensor(self): self.check_tensor_eq(local_in1.grad, dist_in1.grad) self.check_tensor_eq(local_in2.grad, dist_in2.grad) - # input: phi::Tensor - # output: std::vector<phi::Tensor> - def test_unbind_for_dist_tensor(self): - x = np.random.random(size=[2, 8]).astype("float32") - local_in, dist_in = self.create_local_and_dist_tensor_pair(x) - local_out1, local_out2 = paddle.unbind(local_in, axis=0) - dist_out1, dist_out2 = paddle.unbind(dist_in, axis=0) - self.check_tensor_eq(local_out1, dist_out1) - self.check_tensor_eq(local_out2, dist_out2) - - local_out = paddle.concat([local_out1, local_out2]) - dist_out = paddle.concat([dist_out1, dist_out2]) - - local_out.backward() - dist_out.backward() - self.check_tensor_eq(local_in.grad, dist_in.grad) - # input: paddle::optional<phi::Tensor> # output: phi::Tensor def test_expand_as_for_dist_tensor(self): diff --git a/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py new file mode 100644 index 00000000000000..eefc47d6967163 --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_hybrid_strategy.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120, nnode=2) + self._default_envs = { + "dtype": "float32", + "seed": "2023", + } + # this test need to be run on 4-cards environment, but our CI only supports + # 2-cards distribute test, so skip gpu test now + self._changeable_envs = {"backend": ["cpu"]} + + def test_simple_net_bybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_simple_net_hybrid.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/cinn/ir/test_llir_schedule_bind.py b/test/cinn/ir/test_llir_schedule_bind.py new file mode 100644 index 00000000000000..5be0ddf95ae172 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_bind.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_bind_reduce(): + @to_cinn_llir + def reduce_sum(A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256))): + for i1 in range(1): + for j1 in range(4): + for k1 in range(256): + with ir.ScheduleBlockContext("init") as init: + vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) + B[vi, vj, vk] = 0.0 + for l1 in range(512): + with ir.ScheduleBlockContext("B"): + sch.bind(i1, "blockIdx.x") + sch.bind(j1, "threadIdx.y") + sch.bind(k1, "threadIdx.x") + vi1, vj1, vk1, vl1 = ir.AxisMap( + "SSSR", [i1, j1, k1, l1] + ) + B[vi1, vj1, vk1] = ( + B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1] + ) + + @to_cinn_llir + def reduce_sum_expected( + A: DataArray((1, 4, 256, 512)), B: DataArray((1, 4, 256)) + ): + for i1 in range(1): + for j1 in range(4): + for k1 in range(256): + with ir.ScheduleBlockContext("init") as init: + vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) + B[vi, vj, vk] = 0.0 + for l1 in range(512): + with ir.ScheduleBlockContext("B"): + vi1, vj1, vk1, vl1 = ir.AxisMap( + "SSSR", [i1, j1, k1, l1] + ) + B[vi1, vj1, vk1] = ( + B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1] + ) + sch.bind(init.i1, "blockIdx.x") + sch.bind(init.j1, "threadIdx.y") + sch.bind(init.k1, "threadIdx.x") + + assert_llir_equal(reduce_sum, reduce_sum_expected) + + +if __name__ == "__main__": + test_bind_reduce() diff --git a/test/cinn/ir/test_llir_schedule_for_kind.py b/test/cinn/ir/test_llir_schedule_for_kind.py new file mode 100644 index 00000000000000..70dc96ea0715de --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_for_kind.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test.cinn.utils.testing import assert_llir_equal + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +# Current Python DSL cannot express the parallel `for`, +# only checks that it can be converted correctly +def test_elementwise_parallel(): + @to_cinn_llir + def elementwise_add( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("Y"): + i1, j1 = ir.AxisMap("SS", [i, j]) + Y[i1, j1] = A[i1, j1] + 2.0 + sch.parallel(A_block.i) + + assert_llir_equal(elementwise_add, elementwise_add) + + +# Current Python DSL cannot express the vectorize `for`, +# only checks that it can be converted correctly +def test_elementwise_vectorize(): + @to_cinn_llir + def elementwise_add( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i in range(128): + for j0 in range(32): + for j1 in range(4): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1]) + Y[i1, j1] = A[i1, j1] + 2.0 + sch.vectorize(Y_block.j1, 1) + + assert_llir_equal(elementwise_add, elementwise_add) + + +# Current Python DSL cannot express the unroll `for`, +# only checks that it can be converted correctly +def test_elementwise_unroll(): + @to_cinn_llir + def elementwise_add( + X: DataArray((128, 128)), + Y: DataArray((128, 128)), + A: DataArray((128, 128)), + ): + for i in range(128): + for j in range(128): + with ir.ScheduleBlockContext("A") as A_block: + i1, j1 = ir.AxisMap("SS", [i, j]) + A[i1, j1] = X[i1, j1] * 2.0 + for i in range(128): + for j0 in range(32): + for j1 in range(4): + with ir.ScheduleBlockContext("Y") as Y_block: + i1, j1 = ir.AxisMap("SS", [i, j0 * 4 + j1]) + Y[i1, j1] = A[i1, j1] + 2.0 + sch.unroll(Y_block.j1) + + assert_llir_equal(elementwise_add, elementwise_add) + + +if __name__ == "__main__": + test_elementwise_parallel() + test_elementwise_vectorize() + test_elementwise_unroll() diff --git a/test/cinn/ir/test_llir_schedule_rfactor.py b/test/cinn/ir/test_llir_schedule_rfactor.py new file mode 100644 index 00000000000000..098435686c7915 --- /dev/null +++ b/test/cinn/ir/test_llir_schedule_rfactor.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +def test_matmul(): + @to_cinn_llir + def matmul( + A: DataArray((128, 128)), + B: DataArray((128, 128)), + C: DataArray((128, 128)), + ): + for i0 in range(128): + for i1 in range(128): + with ir.ScheduleBlockContext("init"): + vi, vj = ir.AxisMap("SS", [i0, i1]) + C[vi, vj] = 0.0 + for i2_outer in range(4): + for i2_inner_outer in range(8): + for i2_inner_inner in range(4): + with ir.ScheduleBlockContext( + "compute" + ) as Compute_block: + vi, vj, vk = ir.AxisMap( + "SSR", + [ + i0, + i1, + i2_outer * 32 + + i2_inner_outer * 4 + + i2_inner_inner, + ], + ) + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + sch.rfactor(Compute_block.i2_inner_inner, 0) + + # TODO(6clc): rfactor schedule rasie Error Message: iter_value not support complex reduce bindings + # assert_llir_equal(matmul, matmul) + + +if __name__ == "__main__": + test_matmul() diff --git a/test/cinn/runtime/test_launch.py b/test/cinn/runtime/test_launch.py new file mode 100644 index 00000000000000..bb8e3d45aeee5c --- /dev/null +++ b/test/cinn/runtime/test_launch.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cinn +import numpy as np +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray + + +@to_cinn_llir +def bin_op_kernel(X, Y, Z, N): + for idx in range(N): + with ir.ScheduleBlockContext("Z"): + idx1 = ir.AxisMap("S", [idx]) + Z[idx1] = X[idx1] + Y[idx1] + + +def test_launch_fp32(): + N = 10 + X_np = np.random.random(N).astype(np.float32) + Y_np = np.random.random(N).astype(np.float32) + Z_np = np.zeros((N), dtype=np.float32) + target = cinn.common.DefaultNVGPUTarget() + X = DataArray.from_numpy(X_np, target) + Y = DataArray.from_numpy(Y_np, target) + Z = DataArray.from_numpy(Z_np, target) + + # compile and run + bin_op_kernel[target](X, Y, Z, N) + pred = Z.to_numpy() + gt = np.add(X_np, Y_np) + np.testing.assert_allclose(pred, gt) + + +def test_launch_dtype(): + for np_dtype in ( + np.uint16, # convert np.uint16 to bfloat16 in Paddle and CINN + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint32, + np.uint64, + ): + N = 10 + X_np = np.random.random(N).astype(np_dtype) + Y_np = np.random.random(N).astype(np_dtype) + Z_np = np.zeros((N), dtype=np_dtype) + target = cinn.common.DefaultNVGPUTarget() + X = DataArray.from_numpy(X_np, target) + Y = DataArray.from_numpy(Y_np, target) + Z = DataArray.from_numpy(Z_np, target) + + # compile and run + bin_op_kernel[target](X, Y, Z, N) + pred = Z.to_numpy() + + +if __name__ == "__main__": + test_launch_fp32() + test_launch_dtype() diff --git a/test/cinn/runtime/test_reduce_cuda.py b/test/cinn/runtime/test_reduce_cuda.py new file mode 100644 index 00000000000000..3eaf160763bd49 --- /dev/null +++ b/test/cinn/runtime/test_reduce_cuda.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import cinn +import numpy as np +from cinn import ir, to_cinn_llir +from cinn.runtime.data_array import DataArray +from cinn.schedule import IRSchedule as sch + + +@to_cinn_llir +def reduce_max(A, B): + for i1 in range(1): + for j1 in range(2): + for k1 in range(4): + with ir.ScheduleBlockContext("init") as init: + vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) + B[vi, vj, vk] = 0.0 + for l1 in range(8): + with ir.ScheduleBlockContext("B"): + sch.bind(i1, "blockIdx.x") + sch.bind(j1, "threadIdx.y") + sch.bind(k1, "threadIdx.x") + vi1, vj1, vk1, vl1 = ir.AxisMap( + "SSSR", [i1, j1, k1, l1] + ) + B[vi1, vj1, vk1] = ir.Max.make( + B[vi1, vj1, vk1], A[vi1, vj1, vk1, vl1] + ) + + +@to_cinn_llir +def reduce_sum(A, B): + for i1 in range(1): + for j1 in range(2): + for k1 in range(4): + with ir.ScheduleBlockContext("init") as init: + vi, vj, vk = ir.AxisMap("SSS", [i1, j1, k1]) + B[vi, vj, vk] = 0.0 + for l1 in range(8): + with ir.ScheduleBlockContext("B"): + sch.bind(i1, "blockIdx.x") + sch.bind(j1, "threadIdx.y") + sch.bind(k1, "threadIdx.x") + vi1, vj1, vk1, vl1 = ir.AxisMap( + "SSSR", [i1, j1, k1, l1] + ) + B[vi1, vj1, vk1] = ( + B[vi1, vj1, vk1] + A[vi1, vj1, vk1, vl1] + ) + + +def test_reduce_max_cuda(): + # prepare input and output array + d1 = 2 + d2 = 4 + d3 = 8 + a_np = np.random.rand(1, d1, d2, d3).astype("float32") + b_np = a_np.max(axis=-1).astype("float32") + target = cinn.common.DefaultNVGPUTarget() + a = DataArray.from_numpy(a_np, target) + b = DataArray.from_numpy(np.zeros_like(b_np), target) + reduce_max[target](a, b) + np.testing.assert_allclose(b.to_numpy(), b_np, rtol=1e-5, atol=1e-6) + + +def test_reduce_sum_cuda(): + # prepare input and output array + d1 = 2 + d2 = 4 + d3 = 8 + a_np = np.random.rand(1, d1, d2, d3).astype("float32") + b_np = a_np.sum(axis=-1).astype("float32") + target = cinn.common.DefaultNVGPUTarget() + a = DataArray.from_numpy(a_np, target) + b = DataArray.from_numpy(np.zeros_like(b_np), target) + reduce_sum[target](a, b) + np.testing.assert_allclose(b.to_numpy(), b_np, rtol=1e-5, atol=1e-6) + + +if __name__ == "__main__": + test_reduce_max_cuda() + test_reduce_sum_cuda() diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index 309acb6164007d..b1b57cb6cf4f5e 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -297,6 +297,20 @@ if(WITH_NCCL) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT "300") endif() endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_dataparallel_bf16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22024;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_dataparallel_bf16 PROPERTIES TIMEOUT "200") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_dygraph_sharding_stage2 @@ -326,6 +340,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) set_tests_properties(test_dygraph_sharding_stage2_bf16 PROPERTIES TIMEOUT "200") endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_sharding_stage1_fp16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22024;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_sharding_stage1_fp16 PROPERTIES TIMEOUT + "200") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_parallel_dygraph_control_flow @@ -680,11 +709,6 @@ if((WITH_GPU OR WITH_XPU) AND (LINUX OR WIN32)) test_fleet_recompute_meta_optimizer ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") endif() -if(LOCAL_ALL_ARCH AND (LINUX OR WIN32)) - py_test_modules( - test_fleet_private_function MODULES test_fleet_private_function ENVS - "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") -endif() if((WITH_GPU OR WITH_XPU) AND LOCAL_ALL_PLAT) bash_test_modules( test_new_group diff --git a/test/collective/fleet/c_comm_init_op.py b/test/collective/fleet/c_comm_init_op.py index 988c0fcc27954b..15230b9b71f331 100644 --- a/test/collective/fleet/c_comm_init_op.py +++ b/test/collective/fleet/c_comm_init_op.py @@ -17,9 +17,6 @@ import paddle from paddle import base -from paddle.distributed.fleet.base.private_helper_function import ( - wait_server_ready, -) paddle.enable_static() @@ -35,8 +32,6 @@ def setUp(self): self.exe = base.Executor(self.place) self.endpoints.remove(self.current_endpoint) self.other_endpoints = self.endpoints - if self.rank == 0: - wait_server_ready(self.other_endpoints) def test_specifying_devices(self): program = base.Program() diff --git a/test/collective/fleet/dygraph_dataparallel_bf16.py b/test/collective/fleet/dygraph_dataparallel_bf16.py new file mode 100644 index 00000000000000..efc7b6f993d987 --- /dev/null +++ b/test/collective/fleet/dygraph_dataparallel_bf16.py @@ -0,0 +1,198 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + fused_allreduce_gradients, +) +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, use_pure_bf16=False, use_main_grad=False, accumulate_grad=False +): + optimizer = optimizer_setting( + model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad + ) + if use_pure_bf16: + level = 'O2' + custom_white_list = None + model = paddle.amp.decorate( + models=model, + dtype="bfloat16", + level=level, + ) + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + model = paddle.DataParallel(model) + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + if not use_pure_bf16: + for param in model.parameters(): + t = paddle.cast( + paddle.cast(param, dtype='bfloat16'), dtype='float32' + ) + param.set_value(t) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with model.no_sync(): + with paddle.amp.auto_cast( + True, + level=level, + dtype="bfloat16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + loss.backward() + + if not accumulate_grad: + fused_allreduce_gradients(list(model.parameters()), None) + + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + fused_allreduce_gradients(list(model.parameters()), None) + + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_dp_bf16(): + if not paddle.amp.is_bfloat16_supported(): + return + paddle.distributed.init_parallel_env() + mlp = MLP() + state_dict = mlp.state_dict() + + # dp bf16 O1 vs dp bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + losses_o1 = train_mlp(mlp1, use_pure_bf16=False) + losses_o2 = train_mlp(mlp2, use_pure_bf16=True, use_main_grad=True) + for i in range(len(losses_o2)): + loss_o2 = paddle.cast(losses_o2[i], dtype='float32').detach() + loss_o1 = paddle.cast(losses_o1[i], dtype='float32').detach() + np.testing.assert_array_equal(loss_o2, loss_o1) + + # grad accumulation test + mlp3 = MLP() + mlp4 = MLP() + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + losses_acc_grad_o1 = train_mlp( + mlp3, use_pure_bf16=False, accumulate_grad=True + ) + losses_acc_grad_o2 = train_mlp( + mlp4, use_pure_bf16=True, use_main_grad=True, accumulate_grad=True + ) + for i in range(len(losses_acc_grad_o2)): + loss_acc_grad_o2 = paddle.cast( + losses_acc_grad_o2[i], dtype='float32' + ).detach() + loss_acc_grad_o1 = paddle.cast( + losses_acc_grad_o1[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(loss_acc_grad_o2, loss_acc_grad_o1) + + +if __name__ == '__main__': + test_dp_bf16() diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py new file mode 100644 index 00000000000000..601659e0fb98b9 --- /dev/null +++ b/test/collective/fleet/dygraph_group_sharded_stage1_fp16.py @@ -0,0 +1,263 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_fp16, use_main_grad): + if use_main_grad: + assert use_pure_fp16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="float16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_fp16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, + sharding_stage, + use_pure_fp16=False, + accumulate_grad=False, + use_main_grad=False, + test_scaler=False, + scale_loss=1024, +): + scaler = None + if test_scaler: + assert sharding_stage == 1 + assert not accumulate_grad + scaler = paddle.amp.GradScaler(init_loss_scaling=scale_loss) + scaler = fleet.distributed_scaler(scaler) + optimizer = optimizer_setting( + model=model, use_pure_fp16=use_pure_fp16, use_main_grad=use_main_grad + ) + if use_pure_fp16: + level = 'O2' + custom_white_list = None + model = paddle.amp.decorate(models=model, dtype="float16", level=level) + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + + if sharding_stage == 1: + optimizer = fleet.distributed_optimizer(optimizer) + + model = fleet.distributed_model(model) + else: + model = paddle.DataParallel(model) + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + + if sharding_stage == 1: + model.to(device="gpu") + + if not use_pure_fp16: + for param in model.parameters(): + t = paddle.cast( + paddle.cast(param, dtype='float16'), dtype='float32' + ) + param.set_value(t) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with paddle.amp.auto_cast( + True, + level=level, + dtype="float16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + if test_scaler: + assert scaler is not None + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + else: + loss.backward() + if not accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_stage1_fp16(): + if not paddle.amp.is_float16_supported(): + return + paddle.distributed.init_parallel_env() + + strategy = fleet.DistributedStrategy() + hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + scale_loss = 1024 + amp_configs = {"init_loss_scaling": scale_loss, "use_pure_fp16": True} + strategy.hybrid_configs = hybrid_configs + strategy.amp_configs = amp_configs + + fleet.init(is_collective=True, strategy=strategy) + mlp = MLP() + state_dict = mlp.state_dict() + + # stage1 fp16 O1 vs stage1 fp16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + o1_losses = train_mlp( + mlp1, + sharding_stage=1, + use_pure_fp16=False, + scale_loss=scale_loss, + ) + o2_losses = train_mlp( + mlp2, + sharding_stage=1, + use_pure_fp16=True, + use_main_grad=True, + scale_loss=scale_loss, + ) + for i in range(len(o1_losses)): + o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() + o2_32_loss = paddle.cast(o2_losses[i], dtype='float32').detach() + np.testing.assert_array_equal(o1_32_loss, o2_32_loss) + + # stage1 scaler test + mlp3 = MLP() + mlp3.set_state_dict(state_dict) + train_mlp( + mlp3, + sharding_stage=1, + use_pure_fp16=True, + use_main_grad=True, + test_scaler=True, + scale_loss=scale_loss, + ) + + # grad accumulation test + mlp5 = MLP() + mlp6 = MLP() + mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) + o1_losses_grad_acc = train_mlp( + mlp5, + sharding_stage=1, + use_pure_fp16=False, + accumulate_grad=True, + scale_loss=scale_loss, + ) + o2_losses_grad_acc = train_mlp( + mlp6, + sharding_stage=1, + use_pure_fp16=True, + use_main_grad=True, + accumulate_grad=True, + scale_loss=scale_loss, + ) + for i in range(len(o2_losses_grad_acc)): + o2_loss_grad_acc = paddle.cast( + o2_losses_grad_acc[i], dtype='float32' + ).detach() + o1_loss_grad_acc = paddle.cast( + o1_losses_grad_acc[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(o2_loss_grad_acc, o1_loss_grad_acc) + + return + + +if __name__ == '__main__': + test_stage1_fp16() diff --git a/test/collective/fleet/dygraph_group_sharded_stage2.py b/test/collective/fleet/dygraph_group_sharded_stage2.py index 66795a0d2c9be7..81f6df163f1db5 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage2.py +++ b/test/collective/fleet/dygraph_group_sharded_stage2.py @@ -94,6 +94,7 @@ def train_mlp( opt_group=False, save_model=False, test_minimize=False, + scale_fn_test=False, ): if sharding_stage != "dp": group = paddle.distributed.new_group([0, 1], backend="nccl") @@ -104,6 +105,9 @@ def train_mlp( else: optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) + if scale_fn_test: + assert sharding_stage == 2 + if sharding_stage == 2: optimizer = GroupShardedOptimizerStage2( params=optimizer._parameter_list, optim=optimizer, group=group @@ -112,6 +116,13 @@ def train_mlp( model = GroupShardedStage2( model, optimizer, group=group, buffer_max_size=2**21 ) + if scale_fn_test: + param = model.parameters()[0] + grad = paddle.rand(param.shape, dtype=param.dtype) + model._get_scaled_grad_fn(param)(grad) + param.grad = grad + model._get_scaled_grad_fn(param)(None) + return else: model = paddle.DataParallel(model) @@ -178,6 +189,7 @@ def test_dp_stage2(): mlp5 = MLP() mlp6 = MLP() mlp7 = MLP() + mlp8 = MLP() mlp1.set_state_dict(state_dict) mlp2.set_state_dict(state_dict) mlp3.set_state_dict(state_dict) @@ -185,6 +197,7 @@ def test_dp_stage2(): mlp5.set_state_dict(state_dict) mlp6.set_state_dict(state_dict) mlp7.set_state_dict(state_dict) + mlp8.set_state_dict(state_dict) # DP VS stage2 dp_params = train_mlp( @@ -242,6 +255,8 @@ def test_dp_stage2(): # check optimizer.minimize() error train_mlp(mlp7, sharding_stage=2, test_minimize=True) + train_mlp(mlp8, sharding_stage=2, scale_fn_test=True) + if __name__ == '__main__': test_dp_stage2() diff --git a/test/collective/fleet/test_dygraph_dataparallel_bf16.py b/test/collective/fleet/test_dygraph_dataparallel_bf16.py new file mode 100644 index 00000000000000..1401399e8fc4cf --- /dev/null +++ b/test/collective/fleet/test_dygraph_dataparallel_bf16.py @@ -0,0 +1,26 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphDataParallel(TestMultipleGpus): + def test_dygraph_dataparallel_bf16(self): + self.run_mnist_2gpu('dygraph_dataparallel_bf16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_dygraph_sharding_stage1_fp16.py b/test/collective/fleet/test_dygraph_sharding_stage1_fp16.py new file mode 100644 index 00000000000000..580567d40e4f73 --- /dev/null +++ b/test/collective/fleet/test_dygraph_sharding_stage1_fp16.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingStage1(TestMultipleGpus): + # check sharding logic as well as the accuracy with single mode + def test_dygraph_sharding_stage1_fp16(self): + self.run_mnist_2gpu('dygraph_group_sharded_stage1_fp16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_fleet_private_function.py b/test/collective/fleet/test_fleet_private_function.py deleted file mode 100644 index c6a3a197c09ac4..00000000000000 --- a/test/collective/fleet/test_fleet_private_function.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import socket -import threading -import unittest - - -class TestFleetPrivateFunction(unittest.TestCase): - def test_wait_port(self): - def init_server(port): - import time - - time.sleep(5) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("127.0.0.1", port)) - sock.listen(10) - while True: - c, addr = sock.accept() - c.send("0") - c.close() - break - - thr = threading.Thread(target=init_server, args=(9292,)) - thr.start() - - from paddle.distributed import fleet - - ep = ["127.0.0.1:9292"] - fleet.base.private_helper_function.wait_server_ready(ep) - - thr.join() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/collective/fleet/test_fused_attention_pass_with_mp.sh b/test/collective/fleet/test_fused_attention_pass_with_mp.sh index d00f2fdbac0e1d..4b2b48cdc08df8 100644 --- a/test/collective/fleet/test_fused_attention_pass_with_mp.sh +++ b/test/collective/fleet/test_fused_attention_pass_with_mp.sh @@ -17,4 +17,5 @@ set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). +export FLAGS_dynamic_static_unified_comm=0 CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch fused_attention_pass_with_mp.py diff --git a/test/collective/fleet/testslist.csv b/test/collective/fleet/testslist.csv index 664bb0bc8a502d..b9df9ace687cf4 100644 --- a/test/collective/fleet/testslist.csv +++ b/test/collective/fleet/testslist.csv @@ -23,8 +23,10 @@ test_pipeline,,,160,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_pro test_fleet_utils,LINUX;APPLE,,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_static_model_parallel,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_no_sync,,GPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_dygraph_dataparallel_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage2_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_sharding_stage1_fp16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_control_flow,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_hybrid_parallel_inference_helper,,,120,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., @@ -59,7 +61,6 @@ test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../legacy_t test_distributed_strategy,LINUX;APPLE,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_auto_parallel_parallelizer,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_new_group,,GPU;XPU,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=, test_c_comm_init_op,LINUX,GPU;XPU,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=, test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, diff --git a/test/collective/test_communication_api_base.py b/test/collective/test_communication_api_base.py index 7f80730e1ccf14..abd56bfe3d3dfa 100644 --- a/test/collective/test_communication_api_base.py +++ b/test/collective/test_communication_api_base.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import itertools import os import shutil +import socket import subprocess import sys import tempfile @@ -22,7 +24,7 @@ class CommunicationTestDistBase(unittest.TestCase): - def setUp(self, save_log_dir=None, num_of_devices=2, timeout=120): + def setUp(self, save_log_dir=None, num_of_devices=2, timeout=120, nnode=1): self._python_interp = sys.executable self._save_log_dir = save_log_dir self._log_dir = tempfile.TemporaryDirectory() @@ -31,15 +33,43 @@ def setUp(self, save_log_dir=None, num_of_devices=2, timeout=120): self._timeout = timeout self._seeds = [i + 10 for i in range(num_of_devices)] self._devices = ','.join(self._device_list) + self._nnode = nnode + self._port_set = set() + + def _find_free_port(self): + def __free_port(): + with contextlib.closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + while True: + port = __free_port() + if port not in self._port_set: + self._port_set.add(port) + return port def run_test_case(self, script_file, user_defined_envs=None): runtime_envs = os.environ if user_defined_envs is not None: runtime_envs.update(user_defined_envs) runtime_envs["CUDA_VISIBLE_DEVICES"] = self._devices - start_command = f"{self._python_interp} -u -m paddle.distributed.launch --log_dir {self._log_dir.name} --devices {self._devices} {script_file}" + if self._nnode > 1: + start_command = f"{self._python_interp} -u -m paddle.distributed.launch --nnode={self._nnode} --master=127.0.0.1:{self._find_free_port()} --log_dir {self._log_dir.name} --devices {self._devices} {script_file}" + else: + start_command = f"{self._python_interp} -u -m paddle.distributed.launch --log_dir {self._log_dir.name} --devices {self._devices} {script_file}" start_command_list = start_command.strip().split() + if self._nnode > 1: + for i in range(1, self._nnode): + p = subprocess.Popen( + start_command_list, + env=runtime_envs, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + try: self._launcher = subprocess.run( start_command_list, diff --git a/test/cpp/auto_parallel/dist_tensor_test.cc b/test/cpp/auto_parallel/dist_tensor_test.cc index 9882a4b831bb53..a94cfd37d6cc24 100644 --- a/test/cpp/auto_parallel/dist_tensor_test.cc +++ b/test/cpp/auto_parallel/dist_tensor_test.cc @@ -43,7 +43,7 @@ TEST(dist_tensor, constructor) { dist_attr.set_process_mesh(mesh); // copy construct - DenseTensor x1(alloc, meta); + std::shared_ptr<DenseTensor> x1 = std::make_shared<DenseTensor>(alloc, meta); DistTensor dist_x1(x1, dist_attr); EXPECT_TRUE(dist_x1.defined()); EXPECT_TRUE(dist_x1.initialized()); diff --git a/test/cpp/fluid/math/im2col_test.cc b/test/cpp/fluid/math/im2col_test.cc index fab3086a820f20..f3925bce958696 100644 --- a/test/cpp/fluid/math/im2col_test.cc +++ b/test/cpp/fluid/math/im2col_test.cc @@ -89,7 +89,7 @@ void testIm2col() { std::array<float, 8> out_cfo_data = {0, 1, 1, 2, 3, 4, 4, 5}; std::array<float, 8> out_ocf_data = {0, 1, 3, 4, 1, 2, 4, 5}; - float* out_cfo_ptr; + float* out_cfo_ptr = nullptr; if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output_cfo.data<float>(); } else { @@ -101,7 +101,7 @@ void testIm2col() { EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]); } - float* out_ocf_ptr; + float* out_ocf_ptr = nullptr; if (paddle::platform::is_cpu_place(*place)) { out_ocf_ptr = output_ocf.data<float>(); } else { @@ -130,7 +130,7 @@ void testIm2col() { col2im(*context, output_cfo, dilation, stride, padding, &input); - float* in_ptr; + float* in_ptr = nullptr; if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data<float>(); } else { diff --git a/test/cpp/fluid/math/vol2col_test.cc b/test/cpp/fluid/math/vol2col_test.cc index 27a873082a1191..9a6f14c3685cb2 100644 --- a/test/cpp/fluid/math/vol2col_test.cc +++ b/test/cpp/fluid/math/vol2col_test.cc @@ -91,7 +91,7 @@ void testVol2col() { std::array<float, 16> vol_2_col = { 0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; - float* out_cfo_ptr; + float* out_cfo_ptr = nullptr; if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output.data<float>(); } else { @@ -116,7 +116,7 @@ void testVol2col() { phi::funcs::Col2VolFunctor<DeviceContext, float> col2vol; col2vol(*context, output, dilations, strides, paddings, &input); - float* in_ptr; + float* in_ptr = nullptr; if (paddle::platform::is_cpu_place(*place)) { in_ptr = input.data<float>(); } else { diff --git a/test/cpp/fluid/reader/reader_blocking_queue_test.cc b/test/cpp/fluid/reader/reader_blocking_queue_test.cc index 7db47f0761853f..b02f21eb2eb499 100644 --- a/test/cpp/fluid/reader/reader_blocking_queue_test.cc +++ b/test/cpp/fluid/reader/reader_blocking_queue_test.cc @@ -40,7 +40,7 @@ void FirstInFirstOut(size_t queue_cap, size_t count = 0; while (true) { std::this_thread::sleep_for(std::chrono::milliseconds(receive_time_gap)); - size_t elem; + size_t elem = 0; if (!q.Receive(&elem)) { break; } @@ -76,7 +76,7 @@ TEST(BlockingQueue, SenderBlockingTest) { EXPECT_EQ(send_count, queue_cap); std::vector<size_t> res; while (true) { - size_t elem; + size_t elem = 0; if (!q.Receive(&elem)) { break; } @@ -93,7 +93,7 @@ TEST(BlockingQueue, ReceiverBlockingTest) { BlockingQueue<size_t> q(queue_cap); std::vector<size_t> receive_res; std::thread receiver([&]() { - size_t elem; + size_t elem = 0; while (true) { if (!q.Receive(&elem)) { break; @@ -162,7 +162,7 @@ void MultiSenderMultiReceiver(const size_t queue_cap, while (true) { std::this_thread::sleep_for( std::chrono::milliseconds(receive_time_gap)); - size_t elem; + size_t elem = 0; if (!q.Receive(&elem)) { break; } @@ -230,7 +230,7 @@ TEST(BlockingQueue, speed_test_mode) { for (size_t i = 0; i < queue_size; ++i) { q1.Send(i); } - size_t b; + size_t b = 0; for (size_t i = 0; i < queue_size; ++i) { q1.Receive(&b); EXPECT_EQ(b, i); diff --git a/test/cpp/imperative/test_gradient_accmulator.cc b/test/cpp/imperative/test_gradient_accmulator.cc index 982fd81a988358..bb264250ecf567 100644 --- a/test/cpp/imperative/test_gradient_accmulator.cc +++ b/test/cpp/imperative/test_gradient_accmulator.cc @@ -392,7 +392,7 @@ static void TestGradientAccumulatorTestUnchangeInput( int64_t maximum_row_number = 100; std::uniform_int_distribution<int64_t> dist(1, maximum_row_number); - int seed; + int seed = 0; { std::random_device rd; seed = static_cast<int>(rd()); diff --git a/test/cpp/inference/api/CMakeLists.txt b/test/cpp/inference/api/CMakeLists.txt index bbd76ca4344119..8f0b3e5c093335 100644 --- a/test/cpp/inference/api/CMakeLists.txt +++ b/test/cpp/inference/api/CMakeLists.txt @@ -969,6 +969,14 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST) paddle_inference_shared ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) + inference_analysis_test( + trt_disable_tensorrt_half_ops_test + SRCS + trt_disable_tensorrt_half_ops_test.cc + EXTRA_DEPS + paddle_inference_shared + ARGS + --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) inference_analysis_test( trt_fc_prelu_test SRCS @@ -1304,6 +1312,8 @@ if(WITH_TESTING AND WITH_INFERENCE_API_TEST) set_tests_properties(test_trt_dynamic_shape_ernie PROPERTIES TIMEOUT 480) set_tests_properties(trt_mark_trt_engine_outputs_test PROPERTIES TIMEOUT 300) + set_tests_properties(trt_disable_tensorrt_half_ops_test PROPERTIES TIMEOUT + 300) endif() if(WITH_MKLDNN) diff --git a/test/cpp/inference/api/trt_disable_tensorrt_half_ops_test.cc b/test/cpp/inference/api/trt_disable_tensorrt_half_ops_test.cc new file mode 100644 index 00000000000000..68dfd62d019026 --- /dev/null +++ b/test/cpp/inference/api/trt_disable_tensorrt_half_ops_test.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include "test/cpp/inference/api/trt_test_helper.h" + +namespace paddle { +namespace inference { + +TEST(TensorRT, disable_tensorrt_half_ops) { + std::string model_dir = FLAGS_infer_model + "/resnet50"; + AnalysisConfig config; + config.SetModel(model_dir); + config.EnableUseGpu(100, 0); + config.EnableTensorRtEngine( + 1 << 30, 1, 5, AnalysisConfig::Precision::kHalf, false, false); + + paddle_infer::experimental::InternalUtils::DisableTensorRtHalfOps(&config, + {"conv2d"}); + + std::vector<std::vector<PaddleTensor>> inputs_all; + auto predictor = CreatePaddlePredictor(config); + SetFakeImageInput(&inputs_all, model_dir, false, "__model__", ""); + + std::vector<PaddleTensor> outputs; + for (auto &input : inputs_all) { + ASSERT_TRUE(predictor->Run(input, &outputs)); + predictor->ClearIntermediateTensor(); + } +} + +} // namespace inference +} // namespace paddle diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc index 9bdc5c3d3c718d..28a425dbd4ebe9 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc @@ -44,6 +44,7 @@ PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT); bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } @@ -279,5 +280,74 @@ TEST(StandaloneExecutor, if_op) { EXPECT_EQ(res1, true); } +using namespace paddle::dialect; // NOLINT +TEST(StandaloneExecutor, while_op) { + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::ControlFlowDialect>(); + + pir::Program program(ctx); + pir::Block* block = program.block(); + pir::Builder builder(ctx, block); + + auto i = builder + .Build<paddle::dialect::FullOp>( + std::vector<int64_t>{1}, 1, phi::DataType::INT32) + .out(); + + auto ten = builder + .Build<paddle::dialect::FullOp>( + std::vector<int64_t>{1}, 10, phi::DataType::INT32) + .out(); + + // comput condition value: i <= ten + auto cond_value = builder.Build<LessEqualOp>(i, ten).out(); + + auto while_op = + builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten}); + + // { i = i + 1} + pir::Block* body_block = while_op.body_block(); + auto body_i_argument = body_block->AddArgument(i.type()); + auto body_ten_argument = body_block->AddArgument(ten.type()); + builder.SetInsertionPointToStart(body_block); + auto one = + builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32) + .out(); + auto new_i = builder.Build<AddOp>(body_i_argument, one).out(); + + // comput new condition value: new_i <= new_ten + auto new_cond_value = + builder.Build<LessEqualOp>(new_i, body_ten_argument).out(); + + builder.Build<pir::YieldOp>( + std::vector<pir::Value>{new_cond_value, new_i, body_ten_argument}); + + builder.SetInsertionPointAfter(while_op); + + auto kernel_program = PdOpLowerToKernelPass(&program); + + auto place = platform::CPUPlace(); + Scope scope; + InterpreterCore test_core(place, {}, kernel_program->block(), &scope); + + std::stringstream os; + os << reinterpret_cast<NewIRInterpreter*>( + const_cast<InterpreterBaseImpl*>(test_core.Impl())); + std::string out_name = os.str() + "_inner_var_3"; + test_core.SetSkipGcVars({out_name}); + + test_core.Run({}); + + auto out_tensor = + test_core.local_scope() == nullptr + ? scope.FindVar(out_name)->Get<phi::DenseTensor>() + : test_core.local_scope()->FindVar(out_name)->Get<phi::DenseTensor>(); + + bool res0 = out_tensor.data<int>()[0] == 11; + + EXPECT_EQ(res0, true); +} + } // namespace framework } // namespace paddle diff --git a/test/cpp/phi/core/test_ddim.cc b/test/cpp/phi/core/test_ddim.cc old mode 100755 new mode 100644 index 3a8afe131eb4df..a58d86e62aa403 --- a/test/cpp/phi/core/test_ddim.cc +++ b/test/cpp/phi/core/test_ddim.cc @@ -126,7 +126,7 @@ TEST(DDim, Print) { TEST(DDim, Hash) { // hash a DDim - std::size_t h; + std::size_t h = 0; phi::DDim ddim = phi::make_ddim({2, 3, 4}); h = std::hash<phi::DDim>()(ddim); EXPECT_EQ(h, 0xa16fb2b2967ul); diff --git a/test/cpp/pir/cinn/CMakeLists.txt b/test/cpp/pir/cinn/CMakeLists.txt index 6b984f3a03ae91..94716a65447bdb 100644 --- a/test/cpp/pir/cinn/CMakeLists.txt +++ b/test/cpp/pir/cinn/CMakeLists.txt @@ -5,7 +5,6 @@ if(WITH_TESTING AND WITH_CINN) new_ir_compiler_test.cc DEPS new_ir_compiler - convert_to_dialect cinn_runtime_dialect pir phi @@ -13,15 +12,21 @@ if(WITH_TESTING AND WITH_CINN) glog) set_tests_properties(test_new_ir_compiler PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test_old(test_jit_instruction SRCS jit_instruction_test.cc DEPS + interpreter new_ir_compiler) + set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + cc_test_old( - test_jit_instruction + ir_op_fusion_test SRCS - jit_instruction_test.cc + ir_op_fusion_test.cc DEPS - interpreter - new_ir_compiler - convert_to_dialect) - set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN") + op_with_group_merge_pass + pd_op_dialect + cinn_op_dialect + pir + gtest + glog) paddle_test(test_group_op SRCS group_op_test.cc DEPS cinn_op_dialect) set_tests_properties(test_group_op PROPERTIES LABELS "RUN_TYPE=CINN") diff --git a/test/cpp/pir/cinn/ir_op_fusion_test.cc b/test/cpp/pir/cinn/ir_op_fusion_test.cc new file mode 100644 index 00000000000000..a392373358b2af --- /dev/null +++ b/test/cpp/pir/cinn/ir_op_fusion_test.cc @@ -0,0 +1,444 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <glog/logging.h> +#include <gtest/gtest.h> +#include <sstream> + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/op_with_group_merge_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/program.h" + +std::vector<pir::OpResult> BuildInput( + ::pir::Builder* builder, + const std::vector<std::vector<int64_t>>& vec_shapes) { + std::vector<pir::OpResult> vec_res; + for (size_t i = 0; i < vec_shapes.size(); ++i) { + auto op = builder->Build<paddle::dialect::FullOp>( + vec_shapes[i], 1.0, phi::DataType::FLOAT32, phi::CPUPlace()); + + vec_res.push_back(op.result(0)); + } + + return vec_res; +} + +TEST(IROpFusionPass, demo) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + auto inputs = BuildInput(&builder_base, {{10, 10}, {10, 10}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + auto add = builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]); + builder.Build<paddle::dialect::ReluOp>(add.result(0)); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + ASSERT_EQ(res.size(), 1u); +} + +TEST(IROpFusionPass, ElementWise_Fusion_0) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + auto e = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + auto f = builder.Build<paddle::dialect::AddOp>(e, inputs[2]).result(0); + builder.Build<paddle::dialect::AddOp>(f, inputs[2]); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(res.size(), 1u); +} + +// Real test 0 +TEST(IROpFusionPass, Broadcast_Test_0) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{w}, {w}, {h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + auto e = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + auto f = + builder.Build<paddle::dialect::AddOp>(inputs[2], inputs[3]).result(0); + std::vector<int64_t> axes{1}; + std::vector<int64_t> out_shape{h, w}; + auto e1 = + builder.Build<cinn::dialect::BroadcastOp>(e, axes, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(e1, f); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + // ASSERT_EQ(res.size(), 1u); +} + +// Real test 1 +TEST(IROpFusionPass, Broadcast_Test_1) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{w}, {w}, {w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + auto e = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[2], e).result(0); + std::vector<int64_t> axes{1}; + std::vector<int64_t> out_shape{h, w}; + auto e1 = + builder.Build<cinn::dialect::BroadcastOp>(e, axes, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[3], e1); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 2u); +} + +// Real test 2 +TEST(IROpFusionPass, Broadcast_Test_2) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{w}, {w}, {w}, {h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + auto f = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[2], f).result(0); + std::vector<int64_t> axes{1}; + std::vector<int64_t> out_shape{h, w}; + auto f1 = + builder.Build<cinn::dialect::BroadcastOp>(f, axes, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[3], f1); + builder.Build<paddle::dialect::AddOp>(inputs[4], f1); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 2u); +} + +// Real reduce 0 +TEST(IROpFusionPass, reduce_test_0) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{0}; + auto c = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes, true).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes, true).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes, true).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 1u); +} + +// Real reduce 1 +TEST(IROpFusionPass, reduce_test_1) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{0}; + std::vector<int64_t> axes1{1}; + auto c = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes, true).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes1, true).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 2u); +} + +// Real reduce 2 +TEST(IROpFusionPass, reduce_test_2) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{0}; + std::vector<int64_t> axes1{1}; + auto d = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + auto e = builder.Build<cinn::dialect::ReduceSumOp>(d, axes, false).result(0); + auto f = builder.Build<cinn::dialect::ReduceSumOp>(d, axes1, false).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[2], e).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[2], f).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 2u); +} + +// Real reduce 3 +TEST(IROpFusionPass, reduce_test_3) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{0}; + std::vector<int64_t> axes1{1}; + auto e = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + auto f = builder.Build<cinn::dialect::ReduceSumOp>(e, axes, false).result(0); + + builder.Build<paddle::dialect::AddOp>(inputs[2], f).result(0); + + std::vector<int64_t> out_shape{h, w}; + auto f1 = + builder.Build<cinn::dialect::BroadcastOp>(f, axes1, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[2], f1).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 1u); +} + +// Real reduce 4 +TEST(IROpFusionPass, reduce_test_4) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}, {w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{0}; + std::vector<int64_t> axes1{1}; + auto e = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + auto f = builder.Build<cinn::dialect::ReduceSumOp>(e, axes, false).result(0); + + builder.Build<paddle::dialect::AddOp>(inputs[2], f).result(0); + + std::vector<int64_t> out_shape{h, w}; + auto f1 = + builder.Build<cinn::dialect::BroadcastOp>(f, axes1, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[3], f1).result(0); + auto f2 = + builder.Build<cinn::dialect::BroadcastOp>(f, axes1, out_shape).result(0); + builder.Build<paddle::dialect::AddOp>(inputs[3], f2).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 1u); +} + +// Real reduce 5 +TEST(IROpFusionPass, reduce_test_5) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + int h = 32, w = 32; + auto inputs = BuildInput(&builder_base, {{h, w}, {h, w}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{1}; + + auto c = + builder.Build<paddle::dialect::AddOp>(inputs[0], inputs[1]).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(inputs[0], axes, false).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(inputs[1], axes, false).result(0); + builder.Build<cinn::dialect::ReduceSumOp>(c, axes, false).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 1u); +} + +TEST(IROpFusionPass, layer_norm) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ::pir::Program program_base(ctx); + ::pir::Builder builder_base = ::pir::Builder(ctx, program_base.block()); + + auto inputs = BuildInput(&builder_base, {{128, 128, 768}, {768}, {768}}); + + ::pir::Program program(ctx); + ::pir::Builder builder = ::pir::Builder(ctx, program.block()); + + std::vector<int64_t> axes{-1}; + + auto num = builder + .Build<paddle::dialect::FullOp>(std::vector<int64_t>{1}, + 768.0, + phi::DataType::FLOAT32, + phi::CPUPlace()) + .result(0); + auto eps = builder + .Build<paddle::dialect::FullOp>(std::vector<int64_t>{1}, + 1e-5, + phi::DataType::FLOAT32, + phi::CPUPlace()) + .result(0); + + auto sum = builder.Build<cinn::dialect::ReduceSumOp>(inputs[0], axes, true) + .result(0); + std::vector<int64_t> all_axes{0, 1, 2}; + std::vector<int64_t> out_shape1{128, 128, 1}; + auto num1 = + builder.Build<cinn::dialect::BroadcastOp>(num, all_axes, out_shape1) + .result(0); + auto mean = builder.Build<paddle::dialect::DivideOp>(sum, num1).result(0); + auto power = builder.Build<paddle::dialect::MultiplyOp>(inputs[0], inputs[0]) + .result(0); + auto power_sum = + builder.Build<cinn::dialect::ReduceSumOp>(power, axes, true).result(0); + auto mean2 = + builder.Build<paddle::dialect::DivideOp>(power_sum, num1).result(0); + auto power_mean = + builder.Build<paddle::dialect::MultiplyOp>(mean, mean).result(0); + + auto var = + builder.Build<paddle::dialect::SubtractOp>(mean2, power_mean).result(0); + + std::vector<int64_t> out_shape2{128, 128, 768}; + auto sub = + builder.Build<paddle::dialect::SubtractOp>(inputs[0], mean).result(0); + auto eps1 = + builder.Build<cinn::dialect::BroadcastOp>(eps, all_axes, out_shape2) + .result(0); + auto t1 = builder.Build<paddle::dialect::AddOp>(var, eps1).result(0); + auto t2 = builder.Build<paddle::dialect::SqrtOp>(t1).result(0); + auto t3 = builder.Build<paddle::dialect::DivideOp>(sub, t2).result(0); + auto scale = + builder.Build<cinn::dialect::BroadcastOp>(inputs[1], all_axes, out_shape2) + .result(0); + auto bias = + builder.Build<cinn::dialect::BroadcastOp>(inputs[2], all_axes, out_shape2) + .result(0); + auto t5 = builder.Build<paddle::dialect::MultiplyOp>(t3, scale).result(0); + builder.Build<paddle::dialect::MultiplyOp>(t5, bias).result(0); + + auto res = cinn::dialect::ir::OpFusionPassInternal(program); + + auto new_group = + cinn::dialect::ir::GeneralFusionMergePassInternal(&program, res); + + ASSERT_EQ(new_group.size(), 1u); +} diff --git a/test/cpp/pir/cinn/jit_instruction_test.cc b/test/cpp/pir/cinn/jit_instruction_test.cc index 2996bf17c962a7..5e80cd8021a3fa 100644 --- a/test/cpp/pir/cinn/jit_instruction_test.cc +++ b/test/cpp/pir/cinn/jit_instruction_test.cc @@ -27,11 +27,18 @@ #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/program.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" #include "paddle/cinn/hlir/framework/new_ir_compiler.h" #include "paddle/cinn/utils/data_util.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" +#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" +#include "paddle/phi/backends/gpu/gpu_context.h" + +bool simple_cmp(float a, float b) { return std::abs((a - b) / a) < 1e-5; } std::unique_ptr<::pir::Program> BuildProgram() { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); @@ -39,18 +46,29 @@ std::unique_ptr<::pir::Program> BuildProgram() { auto program = std::make_unique<::pir::Program>(ctx); ::pir::Builder builder = ::pir::Builder(ctx, program->block()); - const float value = 2.0; + const float value = 0.5; auto full_op_x = - builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128}, + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{2, 2}, value, phi::DataType::FLOAT32, phi::GPUPlace()); auto full_op_y = - builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64}, + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{2, 2}, + value, + phi::DataType::FLOAT32, + phi::GPUPlace()); + auto full_op_z = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{2, 2}, value, phi::DataType::FLOAT32, phi::GPUPlace()); + + auto sin = builder.Build<paddle::dialect::SinOp>(full_op_x.result(0)); + auto cos = builder.Build<paddle::dialect::CosOp>(full_op_y.result(0)); + auto add = + builder.Build<paddle::dialect::AddOp>(sin.result(0), cos.result(0)); + builder.Build<paddle::dialect::FetchOp>(add.out(), "out", 0); return std::move(program); } @@ -60,43 +78,105 @@ namespace framework { TEST(CinnJitInstruction, Run) { // Step 1: Construct pir::Program std::unique_ptr<::pir::Program> program = BuildProgram(); - EXPECT_EQ(program->block()->size(), 2u); + EXPECT_EQ(program->block()->size(), 7u); // Step 2: Compiler New pir::Program into Runtime Program auto target = cinn::common::DefaultNVGPUTarget(); auto scope = cinn::hlir::framework::BuildScope(target, *program); - ASSERT_EQ(scope->var_names().size(), 2); - cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); - auto runtime_program = ir_compiler.Build(); + std::vector<cinn::hlir::framework::NewIRCompiler*> compiler_list; - // Step 3: Convert into cinn::dialect::RuntimeDialect - std::unique_ptr<::pir::Program> ir_runtime_program = - cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program); + std::set<std::string> checking_cinn_ops = {"pd_op.sin", "pd_op.cos"}; - std::set<std::string> out_names; - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - out_names.insert(name); + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<cinn::dialect::RuntimeDialect>(); + ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>(); + auto ir_program = std::make_unique<::pir::Program>(ctx); + std::string jit_op_name = cinn::dialect::JitKernelOp::name(); + ::pir::OpInfo op_info = ctx->GetRegisteredOpInfo(jit_op_name); + + std::unordered_map<pir::Value, pir::Value> value_map; + for (auto it = program->block()->begin(); it != program->block()->end(); + ++it) { + if (checking_cinn_ops.count((*it)->name())) { + auto ir_compiler = + new cinn::hlir::framework::NewIRCompiler(*program, target, scope); + + std::vector<::pir::Operation*> ops = {*it}; + auto group = std::make_shared<cinn::hlir::framework::newir::Group>(ops); + auto fn_ptr_res = ir_compiler->BuildCUDAJITInfo({group}); + compiler_list.push_back(ir_compiler); + std::unordered_map<std::string, ::pir::Attribute> op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CUDAJITInfoAttribute::get(ctx, fn_ptr_res[0])}, + }; + + auto out_type = (*it)->result(0).type(); + + std::vector<pir::Value> vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + ::pir::Operation* cinn_op = + ::pir::Operation::Create(vec_ins, op_attrs, {out_type}, op_info); + + value_map[(*it)->result(0)] = cinn_op->result(0); + + ir_program->block()->push_back(cinn_op); + } else { + std::vector<pir::Value> vec_ins; + + for (size_t i = 0; i < (*it)->num_operands(); ++i) { + vec_ins.push_back(value_map.at((*it)->operand_source(i))); + } + + auto type1 = (*it)->result(0).type(); + ::pir::OpInfo info1 = ctx->GetRegisteredOpInfo((*it)->name()); + ::pir::Operation* op = ::pir::Operation::Create( + vec_ins, (*it)->attributes(), {type1}, info1); + + ir_program->block()->push_back(op); + + value_map[(*it)->result(0)] = op->result(0); + } } platform::Place place = platform::CUDAPlace(0); + + auto kernel_program = + paddle::dialect::PdOpLowerToKernelPass(ir_program.get(), place); + Scope exe_scope; - InterpreterCore executor(place, {}, ir_runtime_program->block(), &exe_scope); - executor.SetSkipGcVars(out_names); - executor.Run({}); - - // TODO(Aurelius84): Need to replace check with framework::Scope. - const float value = 2.0; - for (auto& name : out_names) { - std::vector<float> data = - cinn::GetTensorData<float>(scope->GetTensor(name), target); - for (int i = 0; i < data.size(); ++i) { - LOG_FIRST_N(INFO, 3) << "data: " << data[i]; - ASSERT_NEAR(data[i], value, 1e-5); - } + paddle::framework::interpreter::ExecutionConfig exe_conf; + exe_conf.create_local_scope = false; + InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + std::set<std::string> out_names; + out_names.insert("out@fetch"); + auto local_names = exe_scope.LocalVarNames(); + for (size_t i = 0; i < local_names.size(); ++i) { + out_names.insert(local_names[i]); } + + executor.SetSkipGcVars(out_names); + executor.Run({}, true); + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>(); + + bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.35701); + bool res1 = simple_cmp(out_tensor.data<float>()[1], 1.35701); + bool res2 = simple_cmp(out_tensor.data<float>()[2], 1.35701); + bool res3 = simple_cmp(out_tensor.data<float>()[3], 1.35701); + + EXPECT_EQ(res0, true); + EXPECT_EQ(res1, true); + EXPECT_EQ(res2, true); + EXPECT_EQ(res3, true); } } // namespace framework diff --git a/test/cpp/pir/cinn/new_ir_compiler_test.cc b/test/cpp/pir/cinn/new_ir_compiler_test.cc index 4b680b1ac89048..c75df1959ceada 100644 --- a/test/cpp/pir/cinn/new_ir_compiler_test.cc +++ b/test/cpp/pir/cinn/new_ir_compiler_test.cc @@ -22,7 +22,6 @@ #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/convert_to_dialect.h" #include "paddle/cinn/hlir/framework/new_ir_compiler.h" #include "paddle/cinn/utils/data_util.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" @@ -151,33 +150,4 @@ TEST(RuntimeDialect, CompilerAndRun) { cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope); auto runtime_program = ir_compiler.Build(); - - // Step 3: Convert into cinn::dialect::RuntimeDialect - std::shared_ptr<::pir::Program> ir_runtime_program = - cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program); - - // Step 4: Run cinn::dialect::RuntimeDialect - for (auto iter = ir_runtime_program->block()->begin(); - iter != ir_runtime_program->block()->end(); - ++iter) { - auto op = (*iter)->dyn_cast<cinn::dialect::JitKernelOp>(); - auto* instr = op.instruction(); - instr->Run(/*name2podargs=*/nullptr, - false, - /*stream=*/nullptr, - /*use_cache=*/true); - } -#ifdef CINN_WITH_CUDA - CUDA_CALL(cudaDeviceSynchronize()); -#endif - - // Step 5: Check Scope Tensor Value. - for (auto& var_name : scope->var_names()) { - std::string name = {var_name.begin(), var_name.end()}; - std::vector<float> data = - cinn::GetTensorData<float>(scope->GetTensor(name), target); - for (int i = 0; i < 1; ++i) { - LOG_FIRST_N(INFO, 10) << "data: " << data[i]; - } - } } diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index 609f1f8eb8d2e9..7536ea2014fe0f 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -24,6 +24,9 @@ #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" using namespace paddle::dialect; // NOLINT + +// example for while_op use +// while(i < ten) { i = i + 1;} TEST(while_op_test, base) { pir::IrContext* ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect<pir::ControlFlowDialect>(); @@ -36,24 +39,15 @@ TEST(while_op_test, base) { auto i = builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32) .out(); - auto ten = builder.Build<FullOp>(std::vector<int64_t>{1}, 10, phi::DataType::INT32) .out(); - auto while_op = builder.Build<WhileOp>( - std::vector<pir::Value>{i, ten}, - std::vector<pir::Type>{builder.int32_type(), builder.int32_type()}); + // comput condition value: i < ten + auto cond_value = builder.Build<LessThanOp>(i, ten).out(); - // while(i < ten) - pir::Block* cond_block = while_op.cond_block(); - auto cond_i_argument = cond_block->AddArgument(i.type()); - auto cond_ten_argument = cond_block->AddArgument(ten.type()); - builder.SetInsertionPointToStart(cond_block); - auto cond_value = - builder.Build<LessThanOp>(cond_i_argument, cond_ten_argument).out(); - builder.Build<pir::CondYieldOp>( - cond_value, std::vector<pir::Value>{cond_i_argument, cond_ten_argument}); + auto while_op = + builder.Build<WhileOp>(cond_value, std::vector<pir::Value>{i, ten}); // { i = i + 1} pir::Block* body_block = while_op.body_block(); @@ -64,12 +58,19 @@ TEST(while_op_test, base) { builder.Build<FullOp>(std::vector<int64_t>{1}, 1, phi::DataType::INT32) .out(); auto new_i = builder.Build<AddOp>(body_i_argument, one).out(); + + // comput new condition value: new_i < new_ten + auto new_cond_value = + builder.Build<LessThanOp>(new_i, body_ten_argument).out(); + builder.Build<pir::YieldOp>( - std::vector<pir::Value>{new_i, body_ten_argument}); + std::vector<pir::Value>{new_cond_value, new_i, body_ten_argument}); builder.SetInsertionPointAfter(while_op); std::stringstream ss; program.Print(ss); LOG(INFO) << ss.str(); + + EXPECT_EQ(while_op.cond(), cond_value); } diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index 0f0ec568bb50aa..ca71cb8fe9eef9 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -65,6 +65,11 @@ file( ${CMAKE_CURRENT_BINARY_DIR}/conditional_block_test.prog EXPECTED_MD5 cf9dc869ca7f69e2d57b38dbf8427134) +file( + DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/while_op_test.prog + ${CMAKE_CURRENT_BINARY_DIR}/while_op_test.prog + EXPECTED_MD5 290164ae52a496332b0be5829fc93bcd) + copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt ${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt) diff --git a/test/cpp/pir/core/TestParserText.txt b/test/cpp/pir/core/TestParserText.txt index 71a6e0425f0c36..9f979c50cc7c32 100644 --- a/test/cpp/pir/core/TestParserText.txt +++ b/test/cpp/pir/core/TestParserText.txt @@ -76,3 +76,23 @@ f16 //CHECK attribute [] //END + +//CHECK type +vec[vec[],vec[]] +//END + +//CHECK attribute +[(Float)inf,(Float)-inf] +//END + +//CHECK attribute +[(Float)-1,(Float)-1.00001,(Double)-1.00001,(Float)-1.1e+30,(Double)1e+200,(Float)0.123456,(Double)0.123456] +//END + +//CHECK type +vec[vec[i8,bf16],vec[]] +//END + +//CHECK type +vec[vec[i8,bf16],vec[],vec[u8]] +//END diff --git a/test/cpp/pir/core/ir_infershape_test.cc b/test/cpp/pir/core/ir_infershape_test.cc index 720d4b238d5ebd..09d3a2fe9b6b17 100644 --- a/test/cpp/pir/core/ir_infershape_test.cc +++ b/test/cpp/pir/core/ir_infershape_test.cc @@ -45,7 +45,7 @@ class OperationTest static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; // NOLINT - static void Verify() {} + static void VerifySig() {} static void InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::CreateInferMeta); fn(infer_meta); diff --git a/test/cpp/pir/core/ir_parser_test.cc b/test/cpp/pir/core/ir_parser_test.cc index 7990d26e8afaf1..91a26c6e970fc1 100644 --- a/test/cpp/pir/core/ir_parser_test.cc +++ b/test/cpp/pir/core/ir_parser_test.cc @@ -86,7 +86,7 @@ TestTask* ParserTest::GetTestTask() { std::string test_type_info; while (test_text.peek() != '\n' && test_text.peek() != ' ' && test_text.peek() != EOF) { - test_type_info += test_text.get(); + test_type_info += test_text.get(); // NOLINT } while (test_text.peek() == '\n' || test_text.peek() == ' ') { @@ -95,10 +95,10 @@ TestTask* ParserTest::GetTestTask() { std::string test_info; while (Peek(5) != "//END" && test_text.peek() != EOF) { - test_info += test_text.get(); + test_info += test_text.get(); // NOLINT } - if (Peek(5) != "//END" || test_info.size() == 0) { + if (Peek(5) != "//END" || static_cast<int>(test_info.size()) == 0) { return nullptr; } @@ -175,7 +175,7 @@ std::string ParserTest::Get(const size_t len) { if (test_text.peek() == EOF) { break; } - str += test_text.get(); + str += test_text.get(); // NOLINT } return str; } diff --git a/test/cpp/pir/core/ir_program_test.cc b/test/cpp/pir/core/ir_program_test.cc index 85f608aa117a28..7ae348d004f53e 100644 --- a/test/cpp/pir/core/ir_program_test.cc +++ b/test/cpp/pir/core/ir_program_test.cc @@ -41,14 +41,14 @@ class AddOp : public pir::Op<AddOp> { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - void Verify(); + void VerifySig(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand, pir::Type sum_type); }; -void AddOp::Verify() { +void AddOp::VerifySig() { if (num_operands() != 2) { throw("The size of inputs must be equal to 2."); } diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index 483299c206129e..ba85e396d41b7c 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -266,3 +266,75 @@ TEST(IrParserTest, StartupProgram) { EXPECT_TRUE(ssp.str() == ss.str()); } + +TEST(OperatorDialectTest, WhileOpProgram) { + auto p = load_from_file("while_op_test.prog"); + EXPECT_EQ(p.Size(), 3u); + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + std::stringstream ss; + program->Print(ss); + + LOG(INFO) << ss.str(); + + EXPECT_EQ(program->block()->size(), 4u); + size_t id = 0; + for (auto &op : *program->block()) { + if (id == 0 || id == 1) { + EXPECT_TRUE(op->isa<paddle::dialect::FullOp>()); + } + if (id == 2) { + EXPECT_TRUE(op->isa<paddle::dialect::LessThanOp>()); + } + if (id == 3) { + EXPECT_TRUE(op->isa<paddle::dialect::WhileOp>()); + EXPECT_EQ(op->num_regions(), 1u); + // body block + pir::Block *body_block = + op->dyn_cast<paddle::dialect::WhileOp>().body_block(); + size_t body_id = 0; + for (auto &op1 : *body_block) { + if (body_id == 0) { + EXPECT_TRUE(op1->isa<paddle::dialect::FullOp>()); + } + if (body_id == 1) { + EXPECT_TRUE(op1->isa<paddle::dialect::ScaleOp>()); + } + if (body_id == 2) { + EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>()); + } + if (body_id == 3) { + pir::Block *body_body_block = + op1->dyn_cast<paddle::dialect::WhileOp>().body_block(); + size_t body_body_id = 0; + for (auto &op2 : *body_body_block) { + if (body_body_id == 0) { + EXPECT_TRUE(op2->isa<paddle::dialect::FullOp>()); + } + if (body_body_id == 1) { + EXPECT_TRUE(op2->isa<paddle::dialect::ScaleOp>()); + } + if (body_body_id == 2) { + EXPECT_TRUE(op2->isa<paddle::dialect::LessThanOp>()); + } + if (body_body_id == 3) { + EXPECT_TRUE(op2->isa<pir::YieldOp>()); + } + body_body_id++; + } + } + if (body_id == 4) { + EXPECT_TRUE(op1->isa<paddle::dialect::LessThanOp>()); + } + if (body_id == 5) { + EXPECT_TRUE(op1->isa<pir::YieldOp>()); + } + body_id++; + } + } + id++; + } +} diff --git a/test/cpp/pir/pass/pass_manager_test.cc b/test/cpp/pir/pass/pass_manager_test.cc index e83764226ebd11..03e7d88d484bca 100644 --- a/test/cpp/pir/pass/pass_manager_test.cc +++ b/test/cpp/pir/pass/pass_manager_test.cc @@ -69,14 +69,14 @@ class AddOp : public pir::Op<AddOp> { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - void Verify(); + void VerifySig(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::OpResult l_operand, pir::OpResult r_operand, pir::Type sum_type); }; -void AddOp::Verify() { +void AddOp::VerifySig() { if (num_operands() != 2) { throw("The size of inputs must be equal to 2."); } diff --git a/test/cpp/pir/pattern_rewrite/CMakeLists.txt b/test/cpp/pir/pattern_rewrite/CMakeLists.txt index 7edd32531be34d..6f92036ecd6944 100644 --- a/test/cpp/pir/pattern_rewrite/CMakeLists.txt +++ b/test/cpp/pir/pattern_rewrite/CMakeLists.txt @@ -8,3 +8,45 @@ endif() cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS ${PATTERN_REWRITE_TEST_DEPS}) + +cc_test_old( + drr_test + SRCS + drr_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_fuse_linear_test + SRCS + drr_fuse_linear_test.cc + DEPS + fused_gemm_epilogue_pass + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_same_type_binding_test + SRCS + drr_same_type_binding_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) +cc_test_old( + drr_attention_fuse_test + SRCS + drr_attention_fuse_test.cc + DEPS + drr + gtest + pd_op_dialect + pir) + +set_tests_properties( + pattern_rewrite_test PROPERTIES ENVIRONMENT + "FLAGS_enable_new_ir_in_executor=true") diff --git a/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc new file mode 100644 index 00000000000000..22252e52beb394 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc @@ -0,0 +1,380 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include <cstdint> +#include <memory> +#include <vector> + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +class MultiHeadMatmulFusePattern + : public pir::drr::DrrPatternBase<MultiHeadMatmulFusePattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + pir::drr::SourcePattern src = ctx->SourcePattern(); + // The first path to matmul with scale (q). + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = + matmul_1(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_1_in_2")); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("matmul_1_out"), src.Tensor("add_1_in_2")); + const auto &full_int_array_1 = + src.Op("pd_op.full_int_array", + {{"value", src.Attr("full_int_array_1_value")}}); + const auto &reshape_1 = src.Op("pd_op.reshape"); + reshape_1({&src.Tensor("add_1_out"), &full_int_array_1()}, + {&src.Tensor("reshape_1_out"), &src.Tensor("reshape_1_xshape")}); + const auto &transpose_1 = src.Op("pd_op.transpose"); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("reshape_1_out")); + const auto &full_1 = + src.Op("pd_op.full", {{"value", src.Attr("full_1_value")}}); + const auto &scale = src.Op("pd_op.scale"); + src.Tensor("scale_out") = scale(src.Tensor("transpose_1_out"), full_1()); + + // The second path to matmul (k). + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_transpose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_2_in_2")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("matmul_2_out"), src.Tensor("add_2_in_2")); + const auto &full_int_array_2 = src.Op("pd_op.full_int_array"); + const auto &reshape_2 = src.Op("pd_op.reshape"); + reshape_2({&src.Tensor("add_2_out"), &full_int_array_2()}, + {&src.Tensor("reshape_2_out"), &src.Tensor("reshape_2_xshape")}); + const auto &transpose_2 = src.Op("pd_op.transpose"); + src.Tensor("transpose_2_out") = transpose_2(src.Tensor("reshape_2_out")); + + // The third path to matmul (v). + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_1_in_1"), src.Tensor("matmul_3_in_2")); + const auto &add_3 = src.Op("pd_op.add"); + src.Tensor("add_3_out") = + add_3(src.Tensor("matmul_3_out"), src.Tensor("add_3_in_2")); + const auto &full_int_array_3 = src.Op("pd_op.full_int_array"); + const auto &reshape_3 = src.Op("pd_op.reshape"); + reshape_3({&src.Tensor("add_3_out"), &full_int_array_3()}, + {&src.Tensor("reshape_3_out"), &src.Tensor("reshape_3_xshape")}); + const auto &transpose_3 = src.Op("pd_op.transpose"); + src.Tensor("transpose_3_out") = transpose_3(src.Tensor("reshape_3_out")); + + // softmax(qk)v + const auto &matmul_4 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_4_transpose_x")}, + {"transpose_y", src.Attr("matmul_4_transpose_y")}}); + src.Tensor("matmul_4_out") = + matmul_4(src.Tensor("scale_out"), src.Tensor("transpose_2_out")); + const auto &add_4 = src.Op("pd_op.add"); + src.Tensor("add_4_out") = + add_4(src.Tensor("matmul_4_out"), src.Tensor("add_4_in_2")); + const auto &softmax = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_axis")}}); + src.Tensor("softmax_out") = softmax(src.Tensor("add_4_out")); + const auto &matmul_5 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_5_transpose_x")}, + {"transpose_y", src.Attr("matmul_5_transpose_y")}}); + src.Tensor("matmul_5_out") = + matmul_5(src.Tensor("softmax_out"), src.Tensor("transpose_3_out")); + const auto &transpose_4 = src.Op("pd_op.transpose"); + src.Tensor("transpose_4_out") = transpose_4(src.Tensor("matmul_5_out")); + const auto &full_int_array_4 = src.Op("pd_op.full_int_array"); + const auto &reshape_4 = src.Op("pd_op.reshape"); + reshape_4({&src.Tensor("transpose_4_out"), &full_int_array_4()}, + {&src.Tensor("reshape_4_out"), &src.Tensor("reshape_4_xshape")}); + + // + // Constraints. + // + src.RequireNativeCall([](const pir::drr::MatchContext &match_ctx) -> bool { + const auto &softmax_axis = match_ctx.Attr<int>("softmax_axis"); + if (softmax_axis != -1 && softmax_axis != 3) return false; + + bool matmul_1_transpose_x = match_ctx.Attr<bool>("matmul_1_transpose_x"); + bool matmul_1_transpose_y = match_ctx.Attr<bool>("matmul_1_transpose_y"); + if (matmul_1_transpose_x || matmul_1_transpose_y) return false; + + bool matmul_2_transpose_x = match_ctx.Attr<bool>("matmul_2_transpose_x"); + bool matmul_2_transpose_y = match_ctx.Attr<bool>("matmul_2_transpose_y"); + if (matmul_2_transpose_x || matmul_2_transpose_y) return false; + + bool matmul_3_transpose_x = match_ctx.Attr<bool>("matmul_3_transpose_x"); + bool matmul_3_transpose_y = match_ctx.Attr<bool>("matmul_3_transpose_y"); + if (matmul_3_transpose_x || matmul_3_transpose_y) return false; + + bool matmul_4_transpose_x = match_ctx.Attr<bool>("matmul_4_transpose_x"); + bool matmul_4_transpose_y = match_ctx.Attr<bool>("matmul_4_transpose_y"); + if (matmul_4_transpose_x || !matmul_4_transpose_y) return false; + + bool matmul_5_transpose_x = match_ctx.Attr<bool>("matmul_5_transpose_x"); + bool matmul_5_transpose_y = match_ctx.Attr<bool>("matmul_5_transpose_y"); + if (matmul_5_transpose_x || matmul_5_transpose_y) return false; + + return true; + }); + + // + // Result Pattern. + // + pir::drr::ResultPattern res = src.ResultPattern(); + // W combine. + const auto &combine_1 = res.Op("builtin.combine"); + combine_1({&res.Tensor("matmul_1_in_2"), + &res.Tensor("matmul_2_in_2"), + &res.Tensor("matmul_3_in_2")}, + {&res.Tensor("combine_1_out")}); + const auto &concat_axis = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> int { return 0; }); + const auto &concat_1 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); + const auto &reshape_5_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> { + auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); + return {-1, 3, matmul_1_in_2.at(1)}; + }); + const auto &reshape_5 = + res.Op("pd_op.reshape", {{"shape", reshape_5_shape}}); + reshape_5({&res.Tensor("concat_1_out")}, + {&res.Tensor("reshape_5_out"), &res.NoneTensor()}); + + // Bias combine. + const auto &combine_2 = res.Op("builtin.combine"); + combine_2({&res.Tensor("add_1_in_2"), + &res.Tensor("add_2_in_2"), + &res.Tensor("add_3_in_2")}, + {&res.Tensor("combine_2_out")}); + const auto &concat_2 = res.Op("pd_op.concat", {{"axis", concat_axis}}); + res.Tensor("concat_2_out") = concat_2(res.Tensor("combine_2_out")); + const auto &reshape_6_shape = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector<int64_t> { + return {3, -1}; + }); + const auto &reshape_6 = + res.Op("pd_op.reshape", {{"shape", reshape_6_shape}}); + reshape_6({&res.Tensor("concat_2_out")}, + {&res.Tensor("reshape_6_out"), &res.NoneTensor()}); + + const auto &head_number = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> int { + const auto &full_int_array_1_value = + match_ctx.Attr<std::vector<int64_t>>("full_int_array_1_value"); + return full_int_array_1_value.at(2); + }); + const auto &alpha = + res.Attr([](const pir::drr::MatchContext &match_ctx) -> float { + return match_ctx.Attr<float>("full_1_value"); + }); + const auto &multihead_matmul = res.Op( + "pd_op.multihead_matmul", + {{"transpose_q", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return false; + })}, + {"transpose_k", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return true; + })}, + {"transpose_v", res.Attr([](const pir::drr::MatchContext &match_ctx) { + return false; + })}, + {"head_number", head_number}, + {"alpha", alpha}}); + multihead_matmul({&res.Tensor("matmul_1_in_1"), + &res.Tensor("reshape_5_out"), + &res.Tensor("reshape_6_out"), + &res.Tensor("add_4_in_2")}, + {&res.Tensor("reshape_4_out")}); + } +}; + +class AttentionFusePass : public pir::Pass { + public: + AttentionFusePass() : pir::Pass("AttentionFusePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(MultiHeadMatmulFusePattern().Build(context)); + // Add other attention variant fuse pattern. + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +namespace pir { +std::unique_ptr<Pass> CreateAttentionFusePass() { + return std::make_unique<AttentionFusePass>(); +} +} // namespace pir + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp matmul_1_in_1 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 300, 256}, + 0.9, + phi::DataType::FLOAT32, + phi::CPUPlace()); + // The first path to matmul with scale (q). + paddle::dialect::FullOp matmul_1_in_2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_1 = builder.Build<paddle::dialect::MatmulOp>( + matmul_1_in_1.out(), matmul_1_in_2.out(), false, false); + + paddle::dialect::FullOp add_1_in_2 = builder.Build<paddle::dialect::FullOp>( + std::vector<int64_t>{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_1 = + builder.Build<paddle::dialect::AddOp>(matmul_1.out(), add_1_in_2.out()); + + paddle::dialect::ReshapeOp reshape_1 = + builder.Build<paddle::dialect::ReshapeOp>( + add_1.out(), std::vector<int64_t>{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_1 = + builder.Build<paddle::dialect::TransposeOp>(reshape_1.out(), + std::vector<int>{0, 2, 1, 3}); + + paddle::dialect::ScaleOp scale_op = builder.Build<paddle::dialect::ScaleOp>( + transpose_1.out(), 0.1767766922712326, 0.0, true); + + // The second path to matmul (k). + paddle::dialect::FullOp matmul_2_in_2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_2 = builder.Build<paddle::dialect::MatmulOp>( + matmul_1_in_1.out(), matmul_2_in_2.out(), false, false); + + paddle::dialect::FullOp add_2_in_2 = builder.Build<paddle::dialect::FullOp>( + std::vector<int64_t>{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + paddle::dialect::AddOp add_op2 = + builder.Build<paddle::dialect::AddOp>(matmul_2.out(), add_2_in_2.out()); + + paddle::dialect::ReshapeOp reshape_2 = + builder.Build<paddle::dialect::ReshapeOp>( + add_op2.out(), std::vector<int64_t>{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_2 = + builder.Build<paddle::dialect::TransposeOp>(reshape_2.out(), + std::vector<int>{0, 2, 1, 3}); + + // The third path to matmul (v). + paddle::dialect::FullOp matmul_3_in_2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::MatmulOp matmul_3 = builder.Build<paddle::dialect::MatmulOp>( + matmul_1_in_1.out(), matmul_3_in_2.out(), false, false); + + paddle::dialect::FullOp add_3_in_2 = builder.Build<paddle::dialect::FullOp>( + std::vector<int64_t>{256}, 1.5, phi::DataType::FLOAT32, phi::CPUPlace()); + + paddle::dialect::AddOp add_3 = + builder.Build<paddle::dialect::AddOp>(matmul_3.out(), add_3_in_2.out()); + + paddle::dialect::ReshapeOp reshape_3 = + builder.Build<paddle::dialect::ReshapeOp>( + add_3.out(), std::vector<int64_t>{0, 0, 8, 32}); + + paddle::dialect::TransposeOp transpose_3 = + builder.Build<paddle::dialect::TransposeOp>(reshape_3.out(), + std::vector<int>{0, 2, 1, 3}); + + // softmax(qk)v + paddle::dialect::MatmulOp matmul_4 = builder.Build<paddle::dialect::MatmulOp>( + scale_op.out(), transpose_2.out(), false, true); + + paddle::dialect::FullOp add_4_in_2 = builder.Build<paddle::dialect::FullOp>( + std::vector<int64_t>{1, 8, 300, 300}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_4 = + builder.Build<paddle::dialect::AddOp>(matmul_4.out(), add_4_in_2.out()); + + paddle::dialect::SoftmaxOp softmax_op = + builder.Build<paddle::dialect::SoftmaxOp>(add_4.out(), -1); + paddle::dialect::MatmulOp matmul_5 = builder.Build<paddle::dialect::MatmulOp>( + softmax_op.out(), transpose_3.out(), false, false); + + paddle::dialect::TransposeOp transpose_4 = + builder.Build<paddle::dialect::TransposeOp>(matmul_5.out(), + std::vector<int>{0, 2, 1, 3}); + + paddle::dialect::ReshapeOp reshape_4 = + builder.Build<paddle::dialect::ReshapeOp>( + transpose_4.out(), std::vector<int64_t>{0, 0, 256}); + + builder.Build<paddle::dialect::FetchOp>(reshape_4.out(), "out", 0); +} + +TEST(DrrTest, AttentionFuse) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + EXPECT_EQ(program.block()->size(), 33u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateAttentionFusePass()); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 20u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc new file mode 100644 index 00000000000000..bb2e091043d0b6 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_fuse_linear_test.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <glog/logging.h> +#include <gtest/gtest.h> +#include <memory> + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/transforms/fused_gemm_epilogue_pass.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 512, 64}, + 1.5); + // linear 1 + paddle::dialect::FullOp full_weight_op1 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op1 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0); + paddle::dialect::MatmulOp matmul_op1 = + builder.Build<paddle::dialect::MatmulOp>(full_input_op1.out(), + full_weight_op1.out()); + paddle::dialect::AddOp add_op1 = builder.Build<paddle::dialect::AddOp>( + matmul_op1.out(), full_bias_op1.out()); + // linear 2 + paddle::dialect::FullOp full_weight_op2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128}, + 1.5); + paddle::dialect::FullOp full_bias_op2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128}, 1.0); + paddle::dialect::MatmulOp matmul_op2 = + builder.Build<paddle::dialect::MatmulOp>(add_op1.out(), + full_weight_op2.out()); + paddle::dialect::AddOp add_op2 = builder.Build<paddle::dialect::AddOp>( + matmul_op2.out(), full_bias_op2.out()); + paddle::dialect::ReluOp relu_op = + builder.Build<paddle::dialect::ReluOp>(add_op2.out()); + // linear 3 + paddle::dialect::FullOp full_weight_op3 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64}, + 1.5); + paddle::dialect::FullOp full_bias_op3 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0); + paddle::dialect::MatmulOp matmul_op3 = + builder.Build<paddle::dialect::MatmulOp>(relu_op.out(), + full_weight_op3.out()); + paddle::dialect::AddOp add_op3 = builder.Build<paddle::dialect::AddOp>( + matmul_op3.out(), full_bias_op3.out()); + paddle::dialect::GeluOp gelu_op1 = + builder.Build<paddle::dialect::GeluOp>(add_op3.out()); + // linear 4 + paddle::dialect::FullOp full_weight_op4 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 64}, 1.5); + paddle::dialect::FullOp full_bias_op4 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64}, 1.0); + paddle::dialect::MatmulOp matmul_op4 = + builder.Build<paddle::dialect::MatmulOp>(gelu_op1.out(), + full_weight_op4.out()); + paddle::dialect::AddOp add_op4 = builder.Build<paddle::dialect::AddOp>( + matmul_op4.out(), full_bias_op4.out()); + paddle::dialect::GeluOp gelu_op2 = + builder.Build<paddle::dialect::GeluOp>(add_op4.out()); + + // backward + paddle::dialect::FullOp full_grad_op = builder.Build<paddle::dialect::FullOp>( + std::vector<int64_t>{1, 512, 64}, 1.0); + + paddle::dialect::GeluGradOp gelu_op2_grad = + builder.Build<paddle::dialect::GeluGradOp>( + add_op4.out(), full_grad_op.out(), false); + // backward linear 4 + paddle::dialect::AddGradOp add_op4_grad = + builder.Build<paddle::dialect::AddGradOp>( + matmul_op4.out(), full_bias_op4.out(), gelu_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op4_grad = + builder.Build<paddle::dialect::MatmulGradOp>( + gelu_op1.out(), full_weight_op4.out(), add_op4_grad.x_grad()); + + paddle::dialect::GeluGradOp gelu_op1_grad = + builder.Build<paddle::dialect::GeluGradOp>( + add_op3.out(), matmul_op4_grad.x_grad(), false); + // backward linear 3 + paddle::dialect::AddGradOp add_op3_grad = + builder.Build<paddle::dialect::AddGradOp>( + matmul_op3.out(), full_bias_op3.out(), gelu_op1_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op3_grad = + builder.Build<paddle::dialect::MatmulGradOp>( + relu_op.out(), full_weight_op3.out(), add_op3_grad.x_grad()); + + paddle::dialect::ReluGradOp relu_op_grad = + builder.Build<paddle::dialect::ReluGradOp>(relu_op.out(), + matmul_op3_grad.x_grad()); + // backward linear 2 + paddle::dialect::AddGradOp add_op2_grad = + builder.Build<paddle::dialect::AddGradOp>( + matmul_op2.out(), full_bias_op2.out(), relu_op_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op2_grad = + builder.Build<paddle::dialect::MatmulGradOp>( + add_op1.out(), full_weight_op2.out(), add_op2_grad.x_grad()); + // backward linear 1 + paddle::dialect::AddGradOp add_op1_grad = + builder.Build<paddle::dialect::AddGradOp>( + matmul_op1.out(), full_bias_op1.out(), matmul_op2_grad.x_grad()); + paddle::dialect::MatmulGradOp matmul_op1_grad = + builder.Build<paddle::dialect::MatmulGradOp>( + full_input_op1.out(), full_weight_op1.out(), add_op1_grad.x_grad()); + + builder.Build<paddle::dialect::FetchOp>(gelu_op2.out(), "out", 0); + builder.Build<paddle::dialect::FetchOp>(matmul_op1_grad.x_grad(), "dx", 1); +} + +TEST(DrrTest, FusedLinear) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 34u); + + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateFusedGemmEpiloguePass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 22u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc new file mode 100644 index 00000000000000..cb4c6e4b0b92f6 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -0,0 +1,332 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <glog/logging.h> +#include <gtest/gtest.h> +#include <memory> + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" + +/* Source pattern: + input1 + / | \ \ \ + / | \ \ \ + full / | | \ \ full_tmp + / | transpos1 | trans2 trans3 \ / | + / | / | | | | \ / | + softmax1 | / | | | | \ / | + \ | / softmax2 | | | add1 | + \ | / \ | \ / | | + layernorm matmul2 matmul1 \ | + / | \ | | \ | + / | \ \ / \ | + / | \ matmul3 add2 + | | | / | \ | + | | | / | \ | + | | | / | \ | + | | | trans4 trans5 trans6 | + | | | | | | | + | | | relu1 softmax3 softmax4 relu2 + | | | | | | | + output0 output1 output2 output3 output4 output5 output6 +*/ + +class SameTypeBindingTestPattern + // This class is for test cases of the same type of OP. + // (without considering the computational logic between OPs, + // only focusing on the process of matching and replacing) + : public pir::drr::DrrPatternBase<SameTypeBindingTestPattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern src = ctx->SourcePattern(); + + // path 1 + const auto &transpose_1 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_1")}}); + src.Tensor("transpose_1_out") = transpose_1(src.Tensor("input_1")); + const auto &softmax_2 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_2_axis")}}); + src.Tensor("softmax_2_out") = softmax_2(src.Tensor("transpose_1_out")); + const auto &matmul_2 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_2_tradnspose_x")}, + {"transpose_y", src.Attr("matmul_2_transpose_y")}}); + src.Tensor("matmul_2_out") = + matmul_2(src.Tensor("softmax_2_out"), src.Tensor("input_1")); + + // path 2 + const auto &full_1 = src.Op("pd_op.full", + {{"shape", src.Attr("shape_1")}, + {"value", src.Attr("value_1")}, + {"dtype", src.Attr("dtype_1")}, + {"place", src.Attr("place_1")}}); + src.Tensor("full_1_out") = full_1(); + const auto &softmax_1 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_1_axis")}}); + src.Tensor("softmax_1_out") = softmax_1(src.Tensor("full_1_out")); + const auto &layernorm_1 = + src.Op("pd_op.layer_norm", + {{"epsilon", src.Attr("layernorm_epsilon")}, + {"begin_norm_axis", src.Attr("layernorm_begin_norm_axis")}}); + layernorm_1({&src.Tensor("transpose_1_out"), + &src.Tensor("full_1_out"), + &src.Tensor("softmax_1_out")}, + {&src.Tensor("output0"), + &src.Tensor("output1"), + &src.Tensor("output2")}); + + // path 3 + const auto &transpose_2 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_2")}}); + const auto &transpose_3 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_3")}}); + const auto &matmul_1 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_1_transpose_x")}, + {"transpose_y", src.Attr("matmul_1_transpose_y")}}); + src.Tensor("matmul_1_out") = matmul_1(transpose_2(src.Tensor("input_1")), + transpose_3(src.Tensor("input_1"))); + const auto &matmul_3 = + src.Op("pd_op.matmul", + {{"transpose_x", src.Attr("matmul_3_transpose_x")}, + {"transpose_y", src.Attr("matmul_3_transpose_y")}}); + src.Tensor("matmul_3_out") = + matmul_3(src.Tensor("matmul_2_out"), src.Tensor("matmul_1_out")); + const auto &transpose_4 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + const auto &transpose_5 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_5")}}); + const auto &transpose_6 = + src.Op("pd_op.transpose", {{"perm", src.Attr("perm_6")}}); + const auto &relu_1 = src.Op("pd_op.relu"); + const auto &softmax_3 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_3_axis")}}); + const auto &softmax_4 = + src.Op("pd_op.softmax", {{"axis", src.Attr("softmax_4_axis")}}); + src.Tensor("output3") = relu_1(transpose_4(src.Tensor("matmul_3_out"))); + src.Tensor("output4") = softmax_3(transpose_5(src.Tensor("matmul_3_out"))); + src.Tensor("output5") = softmax_4(transpose_6(src.Tensor("matmul_3_out"))); + + // path 4 + const auto &full_tmp = src.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + src.Tensor("full_tmp_out") = full_tmp(); + const auto &add_1 = src.Op("pd_op.add"); + src.Tensor("add_1_out") = + add_1(src.Tensor("input_1"), src.Tensor("full_tmp_out")); + const auto &add_2 = src.Op("pd_op.add"); + src.Tensor("add_2_out") = + add_2(src.Tensor("add_1_out"), src.Tensor("full_tmp_out")); + const auto &relu_2 = src.Op("pd_op.relu"); + src.Tensor("output6") = relu_2(src.Tensor("add_2_out")); + + pir::drr::ResultPattern res = src.ResultPattern(); + const auto &transpose_7 = + res.Op("pd_op.transpose", {{"perm", src.Attr("perm_4")}}); + res.Tensor("output0") = transpose_7(res.Tensor("input_1")); + const auto &transpose_8 = + res.Op("pd_op.transpose", {{"perm", src.Attr("perm_5")}}); + res.Tensor("output1") = transpose_8(res.Tensor("input_1")); + const auto &full_2 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_3 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_4 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_5 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + const auto &full_6 = res.Op("pd_op.full", + {{"shape", src.Attr("shape_tmp")}, + {"value", src.Attr("value_tmp")}, + {"dtype", src.Attr("dtype_tmp")}, + {"place", src.Attr("place_tmp")}}); + res.Tensor("output2") = full_2(); + res.Tensor("output3") = full_3(); + res.Tensor("output4") = full_4(); + res.Tensor("output5") = full_5(); + res.Tensor("output6") = full_6(); + } +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op1 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + // path 1 + paddle::dialect::TransposeOp transpose_op1 = + builder.Build<paddle::dialect::TransposeOp>(full_input_op1.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op2 = + builder.Build<paddle::dialect::SoftmaxOp>(transpose_op1.out(), -1); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build<paddle::dialect::MatmulOp>(softmax_op2.out(), + full_input_op1.out()); + + // path 2 + paddle::dialect::FullOp full_op_scale = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{48}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::SoftmaxOp softmax_op_bias = + builder.Build<paddle::dialect::SoftmaxOp>(full_op_scale.out(), -1); + paddle::dialect::LayerNormOp layernorm_op1 = + builder.Build<paddle::dialect::LayerNormOp>( + transpose_op1.out(), full_op_scale.out(), softmax_op_bias.out()); + + // path 3 + paddle::dialect::TransposeOp transpose_op2 = + builder.Build<paddle::dialect::TransposeOp>(full_input_op1.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::TransposeOp transpose_op3 = + builder.Build<paddle::dialect::TransposeOp>(full_input_op1.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::MatmulOp matmul_op1 = + builder.Build<paddle::dialect::MatmulOp>(transpose_op2.out(), + transpose_op3.out()); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build<paddle::dialect::MatmulOp>(matmul_op2.out(), + matmul_op1.out()); + + paddle::dialect::TransposeOp transpose_op4 = + builder.Build<paddle::dialect::TransposeOp>(matmul_op3.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::ReluOp relu_op1 = + builder.Build<paddle::dialect::ReluOp>(transpose_op4.out()); + + paddle::dialect::TransposeOp transpose_op5 = + builder.Build<paddle::dialect::TransposeOp>(matmul_op3.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op3 = + builder.Build<paddle::dialect::SoftmaxOp>(transpose_op5.out(), -1); + + paddle::dialect::TransposeOp transpose_op6 = + builder.Build<paddle::dialect::TransposeOp>(matmul_op3.out(), + std::vector<int>{0, 1, 2}); + + paddle::dialect::SoftmaxOp softmax_op4 = + builder.Build<paddle::dialect::SoftmaxOp>(transpose_op6.out(), -1); + + // path 4 + paddle::dialect::FullOp full_input_op2 = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_op1 = builder.Build<paddle::dialect::AddOp>( + full_input_op1.out(), full_input_op2.out()); + + paddle::dialect::AddOp add_op2 = builder.Build<paddle::dialect::AddOp>( + add_op1.out(), full_input_op2.out()); + + paddle::dialect::ReluOp relu_op2 = + builder.Build<paddle::dialect::ReluOp>(add_op2.out()); + + // tail + paddle::dialect::MatmulOp matmul_op4 = + builder.Build<paddle::dialect::MatmulOp>(layernorm_op1.variance(), + layernorm_op1.mean()); + + paddle::dialect::MatmulOp matmul_op5 = + builder.Build<paddle::dialect::MatmulOp>(relu_op1.out(), + softmax_op3.out()); + + paddle::dialect::MatmulOp matmul_op6 = + builder.Build<paddle::dialect::MatmulOp>(softmax_op4.out(), + relu_op2.out()); + + builder.Build<paddle::dialect::FetchOp>(matmul_op4.out(), "out1", 0); + builder.Build<paddle::dialect::FetchOp>(matmul_op5.out(), "out2", 1); + builder.Build<paddle::dialect::FetchOp>(matmul_op6.out(), "out3", 2); +} + +class DrrPatternRewritePass : public pir::Pass { + public: + DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(SameTypeBindingTestPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +TEST(DrrTest, drr_demo) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 27u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique<DrrPatternRewritePass>()); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 13u); +} diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc new file mode 100644 index 00000000000000..f607fa5a083260 --- /dev/null +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -0,0 +1,232 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <glog/logging.h> +#include <gtest/gtest.h> +#include <memory> + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/transforms/dead_code_elimination_pass.h" + +class RemoveRedundentReshapePattern + : public pir::drr::DrrPatternBase<RemoveRedundentReshapePattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source patterns + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &reshape1 = pat.Op("pd_op.reshape"); + const auto &reshape2 = pat.Op("pd_op.reshape"); + + reshape1({&pat.Tensor("arg0"), &pat.Tensor("shape0")}, + {&pat.Tensor("out1"), &pat.Tensor("xshape_0")}); + reshape2({&pat.Tensor("out1"), &pat.Tensor("shape1")}, + {&pat.Tensor("ret"), &pat.Tensor("xshape_1")}); + + // Result patterns + pir::drr::ResultPattern res = pat.ResultPattern(); + res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, + {&res.Tensor("ret"), &res.Tensor("xshape_1")}); + } +}; + +class FoldExpandToConstantPattern + : public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + // Source Pattern + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &full1 = pat.Op("pd_op.full", + {{"shape", pat.Attr("shape_1")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + const auto &full_int_array1 = + pat.Op("pd_op.full_int_array", + {{"value", pat.Attr("expand_shape_value")}, + {"dtype", pat.Attr("dtype_2")}, + {"place", pat.Attr("place_2")}}); + const auto &expand = pat.Op("pd_op.expand"); + pat.Tensor("ret") = expand(full1(), full_int_array1()); + + // Result patterns + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &full2 = res.Op("pd_op.full", + {{"shape", pat.Attr("expand_shape_value")}, + {"value", pat.Attr("value_1")}, + {"dtype", pat.Attr("dtype_1")}, + {"place", pat.Attr("place_1")}}); + res.Tensor("ret") = full2(); + } +}; + +class RemoveRedundentTransposePattern + : public pir::drr::DrrPatternBase<RemoveRedundentTransposePattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + pir::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &transpose1 = + pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}}); + const auto &transpose2 = + pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_2")}}); + + pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose"))); + + pir::drr::ResultPattern res = pat.ResultPattern(); + const auto &new_perm_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> std::vector<int> { + const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1"); + const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2"); + std::vector<int> new_perm; + for (int v : perm2) { + new_perm.emplace_back(perm1[v]); + } + return new_perm; + }); + const auto &tranpose_continuous = + res.Op("pd_op.transpose", {{"perm", new_perm_attr}}); + + res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); + } +}; + +class RemoveRedundentCastPattern + : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> { + void operator()(pir::drr::DrrPatternContext *ctx) const override { + auto pat = ctx->SourcePattern(); + pat.Tensor("tmp") = pat.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0")); + pat.Tensor("ret") = pat.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp")); + auto res = pat.ResultPattern(); + res.Tensor("ret") = res.Op( + "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); + } +}; + +class RemoveUselessCastPattern + : public pir::drr::DrrPatternBase<RemoveUselessCastPattern> { + public: + void operator()(pir::drr::DrrPatternContext *ctx) const override { + auto pat = ctx->SourcePattern(); + pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0")); + pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype()); + auto res = pat.ResultPattern(); + res.Tensor("ret").Assign(res.Tensor("arg0")); + } +}; + +void BuildProgram(pir::Builder &builder) { // NOLINT + paddle::dialect::FullOp full_input_op = + builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{4, 3, 16}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::FullIntArrayOp full_int_array_op = + builder.Build<paddle::dialect::FullIntArrayOp>( + std::vector<int64_t>{4, 3, 16, 16}, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::ExpandOp expand_op = + builder.Build<paddle::dialect::ExpandOp>(full_input_op.out(), + full_int_array_op.out()); + + paddle::dialect::ReshapeOp reshape_op1 = + builder.Build<paddle::dialect::ReshapeOp>( + expand_op.out(), std::vector<int64_t>{16, 3, 4, 16}); + + paddle::dialect::ReshapeOp reshape_op2 = + builder.Build<paddle::dialect::ReshapeOp>( + reshape_op1.out(), std::vector<int64_t>{16, 3, 4, 16}); + + paddle::dialect::ReluOp relu_op = + builder.Build<paddle::dialect::ReluOp>(reshape_op2.out()); + + paddle::dialect::CastOp cast_op1 = builder.Build<paddle::dialect::CastOp>( + relu_op.out(), phi::DataType::FLOAT64); + + paddle::dialect::CastOp cast_op2 = builder.Build<paddle::dialect::CastOp>( + cast_op1.out(), phi::DataType::FLOAT32); + + paddle::dialect::TransposeOp transpose_op1 = + builder.Build<paddle::dialect::TransposeOp>(cast_op2.out(), + std::vector<int>{0, 2, 1, 3}); + + paddle::dialect::TransposeOp transpose_op2 = + builder.Build<paddle::dialect::TransposeOp>(transpose_op1.out(), + std::vector<int>{1, 0, 2, 3}); + + paddle::dialect::ReluOp relu_op_second = + builder.Build<paddle::dialect::ReluOp>(transpose_op2.out()); + + builder.Build<paddle::dialect::FetchOp>(relu_op_second.out(), "out", 0); +} + +class DrrPatternRewritePass : public pir::Pass { + public: + DrrPatternRewritePass() : pir::Pass("DrrPatternRewritePass", 1) {} + + bool Initialize(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(RemoveRedundentReshapePattern().Build(context)); + ps.Add(RemoveRedundentTransposePattern().Build(context)); + ps.Add(RemoveRedundentCastPattern().Build(context)); + ps.Add(RemoveUselessCastPattern().Build(context)); + ps.Add(FoldExpandToConstantPattern().Build(context)); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation *op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg); + } + + bool CanApplyOn(pir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +TEST(DrrTest, drr_demo) { + pir::IrContext *ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); + pir::Program program(ctx); + pir::Builder builder = pir::Builder(ctx, program.block()); + BuildProgram(builder); + + EXPECT_EQ(program.block()->size(), 14u); + + pir::PassManager pm(ctx); + pm.AddPass(std::make_unique<DrrPatternRewritePass>()); + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + // pm.EnablePassTiming(); + pm.EnableIRPrinting(); + + CHECK_EQ(pm.Run(&program), true); + EXPECT_EQ(program.block()->size(), 7u); +} diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index adfe431a6be2ba..1499ba161bb09d 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -79,11 +79,11 @@ class Operation1 : public pir::Op<Operation1> { static const char *name() { return "test.Operation1"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; // NOLINT - void Verify(); + void VerifySig(); static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; -void Operation1::Verify() { +void Operation1::VerifySig() { auto &attributes = this->attributes(); if (attributes.count("op2_attr1") == 0 || (!attributes.at("op2_attr1").isa<pir::StrAttribute>())) { @@ -390,7 +390,7 @@ class Conv2dFusionOpTest : public pir::Op<Conv2dFusionOpTest, pir::OpResult bias_, pir::OpResult residual_, pir::AttributeMap attributes); - void Verify(); + void VerifySig(); pir::Value input() { return operand_source(0); } pir::Value filter() { return operand_source(1); } pir::Value bias() { return operand_source(2); } @@ -767,7 +767,7 @@ void Conv2dFusionOpTest::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void Conv2dFusionOpTest::Verify() { +void Conv2dFusionOpTest::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: Conv2dFusionOp."; VLOG(4) << "Verifying inputs:"; @@ -1111,9 +1111,12 @@ void BuildProgram(pir::Builder &builder) { // NOLINT // TODO(wilber): Add a normal test. TEST(pattern_rewrite, Patterns) { pir::IrContext *ctx = pir::IrContext::Instance(); + + ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + ctx->GetOrRegisterDialect<pir::BuiltinDialect>(); auto *test_dialect = ctx->GetOrRegisterDialect<Conv2dFusionTestDialect>(); test_dialect->RegisterOp<paddle::dialect::Conv2dFusionOpTest>(); - ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>(); + pir::Program program(ctx); pir::Builder builder = pir::Builder(ctx, program.block()); BuildProgram(builder); @@ -1122,7 +1125,7 @@ TEST(pattern_rewrite, Patterns) { pir::PassManager pm(ctx); pm.AddPass(std::make_unique<TestPass>()); - // pm.AddPass(ir::CreateConstantFoldingPass()); + // pm.AddPass(pir::CreateConstantFoldingPass()); pm.AddPass(pir::CreateDeadCodeEliminationPass()); pm.AddPass(pir::CreateReorderBlockOpsPass()); pm.EnablePassTiming(); diff --git a/test/cpp/pir/tools/test_op.cc b/test/cpp/pir/tools/test_op.cc index 6041efec0e652a..d8ecbb3a2af385 100644 --- a/test/cpp/pir/tools/test_op.cc +++ b/test/cpp/pir/tools/test_op.cc @@ -29,7 +29,7 @@ void BranchOp::Build(pir::Builder &builder, // NOLINT argument.AddSuccessor(target); } -void BranchOp::Verify() const { +void BranchOp::VerifySig() const { IR_ENFORCE((*this)->num_successors() == 1u, "successors number must equal to 1."); IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr"); @@ -45,7 +45,7 @@ void Operation1::Build(pir::Builder &builder, // NOLINT argument.AddOutput(builder.float32_type()); argument.AddAttributes(attributes); } -void Operation1::Verify() const { +void Operation1::VerifySig() const { auto &attributes = this->attributes(); if (attributes.count("op1_attr1") == 0 || !attributes.at("op1_attr1").isa<pir::StrAttribute>()) { diff --git a/test/cpp/pir/tools/test_op.h b/test/cpp/pir/tools/test_op.h index 98f01db37614dc..175a9268390e94 100644 --- a/test/cpp/pir/tools/test_op.h +++ b/test/cpp/pir/tools/test_op.h @@ -34,7 +34,7 @@ class RegionOp : public pir::Op<RegionOp, OneRegionTrait> { static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT - void Verify() const {} + void VerifySig() const {} }; /// @@ -50,7 +50,7 @@ class BranchOp : public pir::Op<BranchOp> { pir::OperationArgument &argument, // NOLINT const std::vector<pir::OpResult> &target_operands, pir::Block *target); - void Verify() const; + void VerifySig() const; }; // Define case op1. @@ -62,7 +62,7 @@ class Operation1 : public pir::Op<Operation1> { static const char *attributes_name[attributes_num]; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT - void Verify() const; + void VerifySig() const; }; // Define op2. @@ -75,7 +75,7 @@ class Operation2 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; @@ -98,7 +98,7 @@ class TraitExampleOp pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsShapeTraitOp1. @@ -111,7 +111,7 @@ class SameOperandsShapeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsShapeTraitOp2. @@ -127,7 +127,7 @@ class SameOperandsShapeTraitOp2 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp1. @@ -143,7 +143,7 @@ class SameOperandsAndResultShapeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp2. @@ -161,7 +161,7 @@ class SameOperandsAndResultShapeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp3. @@ -180,7 +180,7 @@ class SameOperandsAndResultShapeTraitOp3 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsElementTypeTraitOp1. @@ -194,7 +194,7 @@ class SameOperandsElementTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsElementTypeTraitOp2. @@ -211,7 +211,7 @@ class SameOperandsElementTypeTraitOp2 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp1. @@ -227,7 +227,7 @@ class SameOperandsAndResultElementTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp2. @@ -245,7 +245,7 @@ class SameOperandsAndResultElementTypeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp3. @@ -265,7 +265,7 @@ class SameOperandsAndResultElementTypeTraitOp3 pir::Value r_operand, pir::Type out_type1, pir::Type out_type2); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp1. @@ -279,7 +279,7 @@ class SameOperandsAndResultTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp2. @@ -295,7 +295,7 @@ class SameOperandsAndResultTypeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp3. @@ -315,7 +315,7 @@ class SameOperandsAndResultTypeTraitOp3 pir::Type out_type1, pir::Type out_type2); - void Verify() const {} + void VerifySig() const {} }; } // namespace test diff --git a/test/cpp/prim/CMakeLists.txt b/test/cpp/prim/CMakeLists.txt index 6499c2fae6c6e9..f4f3c1fe778f60 100644 --- a/test/cpp/prim/CMakeLists.txt +++ b/test/cpp/prim/CMakeLists.txt @@ -17,27 +17,7 @@ set(prim_generated_deps final_dygraph_function final_dygraph_node if(WITH_CINN) set(CINN_DEPS cinn_compiler) endif() -cc_test_old( - test_comp_static - SRCS - test_static_prim.cc - DEPS - fleet_executor - static_utils - static_prim_api - generated_op - prim_utils - operator - elementwise_mul_op - elementwise_sub_op - fill_constant_op - activation_op - phi - static_global_utils - static_tensor_operants - generated_static_op - ${CINN_DEPS} - python) +paddle_test(test_comp_static SRCS test_static_prim.cc) if(NOT (NOT WITH_PYTHON AND ON_INFER)) if(WITH_CINN) diff --git a/test/cpp/prim/test_static_prim.cc b/test/cpp/prim/test_static_prim.cc index d4f5dcb8998ae7..8fd7d79bacbc37 100644 --- a/test/cpp/prim/test_static_prim.cc +++ b/test/cpp/prim/test_static_prim.cc @@ -31,46 +31,6 @@ PD_DECLARE_bool(prim_enabled); PHI_DECLARE_string(tensor_operants_mode); -PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(pow, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(subtract, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(multiply, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(concat, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than_raw, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(not_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_equal, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_than, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_and, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_or, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_xor, CPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_not, CPU, ALL_LAYOUT); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_DECLARE_KERNEL(full, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(tanh_grad, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(pow, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(scale, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(subtract, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(multiply, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(concat, GPU, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(less_than_raw, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(not_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_equal, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(greater_than, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_and, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_or, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_xor, KPS, ALL_LAYOUT); -PD_DECLARE_KERNEL(bitwise_not, KPS, ALL_LAYOUT); -#endif namespace paddle { namespace prim { @@ -569,20 +529,3 @@ TEST(StaticPrim, TestFlags) { } // namespace prim } // namespace paddle -USE_OP_ITSELF(fill_constant); -USE_OP_ITSELF(tanh); -USE_OP_ITSELF(tanh_grad); -USE_OP_ITSELF(elementwise_mul); -USE_OP_ITSELF(elementwise_sub); -USE_OP_ITSELF(elementwise_pow); -USE_OP_ITSELF(scale); -USE_OP_ITSELF(less_equal); -USE_OP_ITSELF(less_than); -USE_OP_ITSELF(equal); -USE_OP_ITSELF(not_equal); -USE_OP_ITSELF(greater_equal); -USE_OP_ITSELF(greater_than); -USE_OP_ITSELF(bitwise_xor); -USE_OP_ITSELF(bitwise_and); -USE_OP_ITSELF(bitwise_not); -USE_OP_ITSELF(bitwise_or); diff --git a/test/distributed_passes/auto_parallel_pass_test_base.py b/test/distributed_passes/auto_parallel_pass_test_base.py index 69c2d051c7db37..90173e43de5722 100644 --- a/test/distributed_passes/auto_parallel_pass_test_base.py +++ b/test/distributed_passes/auto_parallel_pass_test_base.py @@ -37,6 +37,7 @@ class AutoPallelPassTestBase(DistPassTestBase): def setUp(self): paddle.enable_static() seed = int(os.environ.get('SEED', -1)) + os.environ["FLAGS_dynamic_static_unified_comm"] = "0" if seed <= 0: seed = np.random.randint(low=1, high=1000000, size=[1])[0] os.environ['SEED'] = str(seed) diff --git a/test/distributed_passes/dist_pass_test_base.py b/test/distributed_passes/dist_pass_test_base.py index 72bc7ca78d9de2..945f6f29eeb434 100644 --- a/test/distributed_passes/dist_pass_test_base.py +++ b/test/distributed_passes/dist_pass_test_base.py @@ -64,6 +64,7 @@ def setUp(self): if paddle.is_compiled_with_cuda(): paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + os.environ["FLAGS_dynamic_static_unified_comm"] = "0" seed = int(os.environ.get('SEED', -1)) if seed <= 0: seed = np.random.randint(low=1, high=1000000, size=[1])[0] diff --git a/test/dygraph_to_static/CMakeLists.txt b/test/dygraph_to_static/CMakeLists.txt index 4231938cf1ee62..1beadd642a66e0 100644 --- a/test/dygraph_to_static/CMakeLists.txt +++ b/test/dygraph_to_static/CMakeLists.txt @@ -3,34 +3,9 @@ file( RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") +set(SOT_ENVS SOT_LOG_LEVEL=0 COST_MODEL=False MIN_GRAPH_SIZE=0 STRICT_MODE=0) set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) -set(DY2ST_EAGER_TEST_ENVS ${GC_ENVS}) -set(TEST_EAGER_OPS - test_bmn - test_break_continue - test_ifelse - test_loop - test_mnist_amp - test_mnist_pure_fp16 - test_mobile_net - test_program_translator - test_ptb_lm - test_reinforcement_learning - test_resnet - test_resnet_amp - test_resnet_pure_fp16 - test_se_resnet - test_sentiment - test_seq2seq - test_tsm - test_word2vec - test_yolov3 - test_bert - test_cycle_gan - test_lstm - test_simnet - test_transformer) list(REMOVE_ITEM TEST_OPS test_lac) # NOTE(Aurelius84): In case of Windows CI, if open ON_INFER, RWLOCK of Scope # will be removed and will cause some random failed in multi-thread. @@ -52,12 +27,7 @@ if(NOT WITH_GPU) endif() foreach(TEST_OP ${TEST_OPS}) - list(FIND TEST_EAGER_OPS ${TEST_OP} WAS_FOUND) - if(NOT WAS_FOUND EQUAL -1) - py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${DY2ST_EAGER_TEST_ENVS}) - else() - py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) - endif() + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS} ${SOT_ENVS}) endforeach() set_tests_properties(test_se_resnet PROPERTIES TIMEOUT 900) @@ -67,10 +37,11 @@ set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 120) set_tests_properties(test_seq2seq PROPERTIES TIMEOUT 150) set_tests_properties(test_cycle_gan PROPERTIES TIMEOUT 150) set_tests_properties(test_bert PROPERTIES TIMEOUT 180) -set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 120) +set_tests_properties(test_basic_api_transformation PROPERTIES TIMEOUT 240) set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120) set_tests_properties(test_transformer PROPERTIES TIMEOUT 200) -set_tests_properties(test_bmn PROPERTIES TIMEOUT 120) +set_tests_properties(test_bmn PROPERTIES TIMEOUT 300) +set_tests_properties(test_bert PROPERTIES TIMEOUT 240) #set_tests_properties(test_mnist PROPERTIES TIMEOUT 120) set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120) diff --git a/test/dygraph_to_static/dygraph_to_static_util.py b/test/dygraph_to_static/dygraph_to_static_util.py index 3202621228710c..9a5b9bf22d92a4 100644 --- a/test/dygraph_to_static/dygraph_to_static_util.py +++ b/test/dygraph_to_static/dygraph_to_static_util.py @@ -49,7 +49,8 @@ def to_sot(func): """ convert run fall_back to ast """ - enable_sot = os.environ.get("ENABLE_SOT", "False") == "True" + # TODO(SigureMo): ENABLE_SOT should always be True, remove this + enable_sot = os.environ.get("ENABLE_SOT", "True") == "True" def impl(*args, **kwargs): if enable_sot: diff --git a/test/dygraph_to_static/dygraph_to_static_utils_new.py b/test/dygraph_to_static/dygraph_to_static_utils_new.py new file mode 100644 index 00000000000000..de74552e3248d1 --- /dev/null +++ b/test/dygraph_to_static/dygraph_to_static_utils_new.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import inspect +import logging +import os +import unittest +from enum import Flag, auto +from functools import wraps + +import numpy as np + +from paddle import set_flags, static +from paddle.base import core + +""" +# Usage: +class MyTest(Dy2StTestBase): + @set_to_static_mode( + ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST + ) + @set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR) + def test_case1(self): + raise ValueError("MyTest 1") + + def test_case2(self): + raise ValueError("MyTest 2") + + +class MyTest2(MyTest): + def test_case1(self): + raise ValueError("MyTest2 1") +""" + +logger = logging.getLogger("Dygraph to static utils") +logger.setLevel(logging.WARNING) + + +class ToStaticMode(Flag): + LEGACY_AST = auto() + PIR_AST = auto() + SOT = auto() + + def lower_case_name(self): + return self.name.lower() + + +class IrMode(Flag): + LEGACY_PROGRAM = auto() + PIR = auto() + + def lower_case_name(self): + return self.name.lower() + + +DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT +DEFAULT_IR_MODE = IrMode.LEGACY_PROGRAM + + +def in_sot_mode(): + return os.getenv("ENABLE_FALL_BACK", "False") == "True" + + +@contextlib.contextmanager +def enable_fallback_guard(enable): + flag = os.environ.get("ENABLE_FALL_BACK", None) + os.environ["ENABLE_FALL_BACK"] = enable + yield + if flag is not None: + os.environ["ENABLE_FALL_BACK"] = flag + else: + del os.environ["ENABLE_FALL_BACK"] + + +def to_legacy_ast_test(fn): + """ + convert run fall_back to ast + """ + + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[AST] running AST") + with enable_fallback_guard("False"): + fn(*args, **kwargs) + + return impl + + +def to_sot_test(fn): + """ + convert run fall_back to ast + """ + + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[SOT] running SOT") + with enable_fallback_guard("True"): + fn(*args, **kwargs) + + return impl + + +def to_pir_ast_test(fn): + raise TypeError("Don't enable PIR AST mode now!") + + +def to_legacy_program_test(fn): + def impl(*args, **kwargs): + logger.info("[Program] running legacy program") + return fn(*args, **kwargs) + + return impl + + +def to_pir_test(fn): + @wraps(fn) + def impl(*args, **kwargs): + logger.info("[PIR] running pir") + ir_outs = None + if os.environ.get('FLAGS_use_stride_kernel', False): + return + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + try: + new_ir_flag = 'FLAGS_enable_new_ir_in_executor' + os.environ[new_ir_flag] = 'True' + set_flags({new_ir_flag: True}) + ir_outs = fn(*args, **kwargs) + finally: + del os.environ[new_ir_flag] + set_flags({new_ir_flag: False}) + return ir_outs + + return impl + + +# Metaclass and BaseClass +class Dy2StTestMeta(type): + TO_STATIC_HANDLER_MAP = { + ToStaticMode.SOT: to_sot_test, + ToStaticMode.LEGACY_AST: to_legacy_ast_test, + ToStaticMode.PIR_AST: to_pir_ast_test, + } + + IR_HANDLER_MAP = { + IrMode.LEGACY_PROGRAM: to_legacy_program_test, + IrMode.PIR: to_pir_test, + } + + def __new__(cls, name, bases, attrs): + new_attrs = {} + original_test_cases = { + key: value + for key, value in attrs.items() + if key.startswith("test") and inspect.isfunction(value) + } + logger.info(f"[creating {name}]") + new_attrs.update( + { + key: value + for key, value in attrs.items() + if key not in original_test_cases + } + ) + for fn_name, fn in original_test_cases.items(): + logger.info(f"Generating {fn_name}") + # Disable inherited test cases + for base in bases: + for attr in dir(base): + if attr.startswith(fn_name): + new_attrs[attr] = None + fn_to_static_modes = getattr( + fn, "to_static_mode", DEFAULT_TO_STATIC_MODE + ) + fn_ir_modes = getattr(fn, "ir_mode", DEFAULT_IR_MODE) + fn_disabled_test_cases = getattr(fn, "disabled_test_cases", []) + logger.info(f"fn_to_static_modes: {fn_to_static_modes}") + logger.info(f"fn_ir_modes: {fn_ir_modes}") + logger.info(f"fn_disabled_test_cases: {fn_disabled_test_cases}") + # Get all valid test cases with to_static_mode and ir_mode + to_static_with_ir_modes = [ + (to_static_mode, ir_mode) + for to_static_mode in ToStaticMode + for ir_mode in IrMode + if to_static_mode & fn_to_static_modes and ir_mode & fn_ir_modes + ] + # Filter out disabled test cases and test cases already in compare groups + to_static_with_ir_modes = list( + filter( + lambda flags: (flags not in fn_disabled_test_cases), + to_static_with_ir_modes, + ) + ) + # Generate all test cases + for to_static_mode, ir_mode in to_static_with_ir_modes: + if ( + to_static_mode == ToStaticMode.PIR_AST + and ir_mode == IrMode.LEGACY_PROGRAM + ): + # PIR with LEGACY_PROGRAM is not a valid combination + continue + new_attrs[ + Dy2StTestMeta.test_case_name( + fn_name, to_static_mode, ir_mode + ) + ] = Dy2StTestMeta.convert_test_case(fn, to_static_mode, ir_mode) + return type.__new__(cls, name, bases, new_attrs) + + @staticmethod + def test_case_name(original_name: str, to_static_mode, ir_mode): + return f"{original_name}__{to_static_mode.lower_case_name()}_{ir_mode.lower_case_name()}" + + @staticmethod + def convert_test_case(fn, to_static_mode, ir_mode): + fn = Dy2StTestMeta.IR_HANDLER_MAP[ir_mode](fn) + fn = Dy2StTestMeta.TO_STATIC_HANDLER_MAP[to_static_mode](fn) + return fn + + +class Dy2StTestBase(unittest.TestCase, metaclass=Dy2StTestMeta): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +# Base decorators +def set_to_static_mode(mode: ToStaticMode): + def decorator(fn): + fn.to_static_mode = mode + return fn + + return decorator + + +def set_ir_mode(mode: IrMode): + def decorator(fn): + fn.ir_mode = mode + return fn + + return decorator + + +def disable_test_case(flags): + def decorator(fn): + disabled_test_cases = getattr(fn, "disabled_test_cases", []) + disabled_test_cases.append(flags) + fn.disabled_test_cases = disabled_test_cases + return fn + + return decorator + + +# Suger decorators +# These decorators can be simply composed by base decorators +def ast_only_test(fn): + fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn) + return fn + + +def sot_only_test(fn): + fn = set_to_static_mode(ToStaticMode.SOT)(fn) + return fn + + +def test_with_new_ir(fn): + fn = set_ir_mode(IrMode.PIR)(fn) + return fn + + +def _test_and_compare_with_new_ir(fn): + @wraps(fn) + def impl(*args, **kwargs): + outs = fn(*args, **kwargs) + if core._is_bwd_prim_enabled() or core._is_fwd_prim_enabled(): + return outs + ir_outs = to_pir_test(fn)(*args, **kwargs) + np.testing.assert_equal( + outs, + ir_outs, + err_msg=f'Dy2St Unittest Check ({fn.__name__}) has diff \n' + + f'Expect {outs}\n' + + f'But Got {ir_outs}', + ) + return outs + + return impl + + +def test_and_compare_with_new_ir(need_check_output: bool = True): + def decorator(fn): + fn = set_ir_mode(IrMode.LEGACY_PROGRAM | IrMode.PIR)(fn) + if need_check_output: + logger.info(f"[need_check_output] {fn.__name__}") + fn = _test_and_compare_with_new_ir(fn) + return fn + + return decorator + + +# For debug +def show_all_test_cases(test_class): + logger.info(f"[showing {test_class.__name__}]") + for attr in dir(test_class): + if attr.startswith("test"): + fn = getattr(test_class, attr) + logger.info(f"{attr}: {fn}") diff --git a/test/dygraph_to_static/test_assert.py b/test/dygraph_to_static/test_assert.py index dc01413d0c8bec..210e904454fd93 100644 --- a/test/dygraph_to_static/test_assert.py +++ b/test/dygraph_to_static/test_assert.py @@ -15,7 +15,11 @@ import unittest import numpy -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -33,7 +37,8 @@ def dyfunc_assert_non_variable(x=True): assert x -class TestAssertVariable(unittest.TestCase): +# @dy2static_unittest +class TestAssertVariable(Dy2StTestBase): def _run(self, func, x, with_exception, to_static): paddle.jit.enable_to_static(to_static) if with_exception: @@ -49,6 +54,7 @@ def _run_dy_static(self, func, x, with_exception): self._run(func, x, with_exception, False) @test_and_compare_with_new_ir(False) + @ast_only_test def test_non_variable(self): self._run_dy_static( dyfunc_assert_non_variable, x=False, with_exception=True @@ -58,6 +64,7 @@ def test_non_variable(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_bool_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([False]), with_exception=True @@ -67,6 +74,7 @@ def test_bool_variable(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_int_variable(self): self._run_dy_static( dyfunc_assert_variable, x=numpy.array([0]), with_exception=True diff --git a/test/dygraph_to_static/test_ast_util.py b/test/dygraph_to_static/test_ast_util.py index 52920d81433c69..c2468765e34387 100644 --- a/test/dygraph_to_static/test_ast_util.py +++ b/test/dygraph_to_static/test_ast_util.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from ifelse_simple_func import ( dyfunc_with_if_else, dyfunc_with_if_else2, @@ -31,7 +35,8 @@ from paddle.utils import gast -class TestAST2Func(unittest.TestCase): +# @dy2static_unittest +class TestAST2Func(Dy2StTestBase): """ TestCase for the transformation from ast.AST into python callable function. """ @@ -43,6 +48,7 @@ def _ast2func(self, func): transformed_func, _ = ast_to_func(ast_root, func) return transformed_func + @ast_only_test def test_ast2func(self): def func(x, y): return x + y @@ -50,6 +56,7 @@ def func(x, y): x, y = 10, 20 self.assertEqual(func(x, y), self._ast2func(func)(x, y)) + @ast_only_test def test_ast2func_dygraph(self): paddle.disable_static() funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else] @@ -62,6 +69,7 @@ def test_ast2func_dygraph(self): self.assertTrue((true_ret == test_ret).all()) @test_and_compare_with_new_ir(False) + @ast_only_test def test_ast2func_static(self): paddle.enable_static() @@ -80,6 +88,7 @@ def func(x): ret = exe.run(main_program, fetch_list=[true_ret, test_ret]) self.assertTrue((ret[0] == ret[1]).all()) + @ast_only_test def test_ast2func_error(self): with self.assertRaises(Exception) as e: self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo')) diff --git a/test/dygraph_to_static/test_backward_without_params.py b/test/dygraph_to_static/test_backward_without_params.py index af70b9e7a2f95f..336d96f2399b53 100644 --- a/test/dygraph_to_static/test_backward_without_params.py +++ b/test/dygraph_to_static/test_backward_without_params.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) import paddle @@ -24,16 +27,16 @@ class Net(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static def forward(self, x): out = x + 1 return out -class TestBackwardWithoutParams(unittest.TestCase): +# @dy2static_unittest +class TestBackwardWithoutParams(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_run(self): - net = Net() + net = paddle.jit.to_static(Net()) x = paddle.ones([2, 2]) x.stop_gradient = False @@ -47,7 +50,6 @@ class ZeroSizeNet(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static def forward(self, x): y = paddle.randn((0,)) out = paddle.nn.functional.relu(x) @@ -55,10 +57,11 @@ def forward(self, x): return y, out -class TestZeroSizeNet(unittest.TestCase): +# @dy2static_unittest +class TestZeroSizeNet(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_run(self): - net = ZeroSizeNet() + net = paddle.jit.to_static(ZeroSizeNet()) x = paddle.ones([2, 2]) x.stop_gradient = False _, out = net(x) diff --git a/test/dygraph_to_static/test_basic_api_transformation.py b/test/dygraph_to_static/test_basic_api_transformation.py index efa9caa17dd515..e0998b8fe1e67f 100644 --- a/test/dygraph_to_static/test_basic_api_transformation.py +++ b/test/dygraph_to_static/test_basic_api_transformation.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base, to_tensor @@ -69,6 +72,7 @@ def dyfunc_bool_to_tensor(x): return paddle.to_tensor(True) +@dy2static_unittest class TestDygraphBasicApi_ToVariable(unittest.TestCase): def setUp(self): self.input = np.ones(5).astype("int32") @@ -230,6 +234,7 @@ def dyfunc_Prelu(input): return res +@dy2static_unittest class TestDygraphBasicApi(unittest.TestCase): # Compare results of dynamic graph and transformed static graph function which only # includes basic Api. @@ -396,6 +401,7 @@ def dyfunc_PolynomialDecay(): return paddle.to_tensor(lr) +@dy2static_unittest class TestDygraphBasicApi_CosineDecay(unittest.TestCase): def setUp(self): self.dygraph_func = dyfunc_CosineDecay @@ -539,6 +545,7 @@ def _dygraph_fn(): np.random.random(1) +@dy2static_unittest class TestDygraphApiRecognition(unittest.TestCase): def setUp(self): self.src = inspect.getsource(_dygraph_fn) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index c7b5272ff47659..ba8e2350794aad 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -20,7 +20,11 @@ import numpy as np from bert_dygraph_model import PretrainModelLayer from bert_utils import get_bert_config, get_feed_data_reader -from dygraph_to_static_util import ast_only_test, test_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_with_new_ir, +) from predictor_utils import PredictorTools import paddle @@ -74,6 +78,7 @@ def __len__(self): return len(self.src_ids) +@dy2static_unittest class TestBert(unittest.TestCase): def setUp(self): self.bert_config = get_bert_config() diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py index d3a2162dc787e1..a803c1d4bf49ed 100644 --- a/test/dygraph_to_static/test_break_continue.py +++ b/test/dygraph_to_static/test_break_continue.py @@ -205,6 +205,7 @@ def test_optim_break_in_while(x): return x +@dy2static_unittest class TestContinueInFor(unittest.TestCase): def setUp(self): self.input = np.zeros(1).astype('int64') diff --git a/test/dygraph_to_static/test_build_strategy.py b/test/dygraph_to_static/test_build_strategy.py index 83ed8d56751dd9..85e934afb020bb 100644 --- a/test/dygraph_to_static/test_build_strategy.py +++ b/test/dygraph_to_static/test_build_strategy.py @@ -84,6 +84,7 @@ def test_in_static_mode_mkldnn(self): paddle.base.set_flags({'FLAGS_use_mkldnn': False}) +@dy2static_unittest class TestError(unittest.TestCase): def test_type_error(self): def foo(x): diff --git a/test/dygraph_to_static/test_cache_program.py b/test/dygraph_to_static/test_cache_program.py index 0602b15b3054be..199c3e980e20c9 100644 --- a/test/dygraph_to_static/test_cache_program.py +++ b/test/dygraph_to_static/test_cache_program.py @@ -76,6 +76,7 @@ def setUp(self): self.data = np.random.random((4, 10)).astype('float32') +@dy2static_unittest class TestCacheProgramWithOptimizer(unittest.TestCase): def setUp(self): self.dygraph_class = Linear @@ -125,6 +126,7 @@ def simple_func(x): return mean +@dy2static_unittest class TestConvertWithCache(unittest.TestCase): def test_cache(self): static_func = convert_to_static(simple_func) @@ -155,6 +157,7 @@ def sum_under_while(limit): return ret_sum +@dy2static_unittest class TestToOutputWithCache(unittest.TestCase): def test_output(self): with base.dygraph.guard(): diff --git a/test/dygraph_to_static/test_cast.py b/test/dygraph_to_static/test_cast.py index 7e2b0914a5fff5..a01f2712cc764d 100644 --- a/test/dygraph_to_static/test_cast.py +++ b/test/dygraph_to_static/test_cast.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from paddle import base from paddle.jit.api import to_static @@ -60,7 +64,8 @@ def test_mix_cast(x): return x -class TestCastBase(unittest.TestCase): +# @dy2static_unittest +class TestCastBase(Dy2StTestBase): def setUp(self): self.place = ( base.CUDAPlace(0) @@ -90,6 +95,7 @@ def do_test(self): @ast_only_test # TODO: add new symbolic only test. @test_and_compare_with_new_ir(False) + # @set_to_static_mode(ToStaticMode.LEGACY_AST) def test_cast_result(self): res = self.do_test().numpy() self.assertTrue( @@ -186,9 +192,11 @@ def prepare(self): def set_func(self): self.func = test_not_var_cast - @ast_only_test # TODO: add new symbolic only test. + @ast_only_test @test_and_compare_with_new_ir(False) def test_cast_result(self): + # breakpoint() + # print("run once!!!") res = self.do_test() self.assertTrue(type(res) == int, msg='The casted dtype is not int.') ref_val = int(self.input) diff --git a/test/dygraph_to_static/test_cinn.py b/test/dygraph_to_static/test_cinn.py index 59a114d0aae586..84e619149c8009 100644 --- a/test/dygraph_to_static/test_cinn.py +++ b/test/dygraph_to_static/test_cinn.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -42,6 +45,7 @@ def apply_to_static(net, use_cinn): return paddle.jit.to_static(net, build_strategy=build_strategy) +@dy2static_unittest class TestCINN(unittest.TestCase): def setUp(self): self.x = paddle.randn([2, 4]) diff --git a/test/dygraph_to_static/test_cinn_prim.py b/test/dygraph_to_static/test_cinn_prim.py index 0bf905ec846f9f..2ed5326f7b9d00 100644 --- a/test/dygraph_to_static/test_cinn_prim.py +++ b/test/dygraph_to_static/test_cinn_prim.py @@ -172,6 +172,7 @@ def test_cinn_prim(self): ) +@dy2static_unittest class TestBackend(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_backend(self): diff --git a/test/dygraph_to_static/test_cinn_prim_layer_norm.py b/test/dygraph_to_static/test_cinn_prim_layer_norm.py index 18c48883d75a68..42bf36d731eca6 100644 --- a/test/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/test/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle import paddle.nn.functional as F @@ -52,6 +52,7 @@ def forward(self, x, w, b): return out[0] +@dy2static_unittest class TestPrimForward(unittest.TestCase): """ This case only tests prim_forward + to_static + cinn. Thus we need to @@ -124,6 +125,7 @@ def test_cinn_prim_forward(self): ) +@dy2static_unittest class TestPrimForwardAndBackward(unittest.TestCase): """ Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph diff --git a/test/dygraph_to_static/test_closure_analysis.py b/test/dygraph_to_static/test_closure_analysis.py index 95234565a6922f..de1d1e12d6502a 100644 --- a/test/dygraph_to_static/test_closure_analysis.py +++ b/test/dygraph_to_static/test_closure_analysis.py @@ -13,10 +13,12 @@ # limitations under the License. import inspect -import os import unittest -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) from numpy import append import paddle @@ -161,7 +163,7 @@ def test_push_pop_4(x, *args, **kargs): return l, k -class TestClosureAnalysis(unittest.TestCase): +class TestClosureAnalysis(Dy2StTestBase): def setUp(self): self.judge_type = "var and w_vars" self.init_dygraph_func() @@ -260,7 +262,7 @@ def init_dygraph_func(self): ] -class TestPushPopTrans(unittest.TestCase): +class TestPushPopTrans(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test(self): def vlist_of_dict(x): @@ -270,7 +272,6 @@ def vlist_of_dict(x): return ma x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -284,7 +285,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -298,7 +298,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -312,7 +311,6 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) @test_and_compare_with_new_ir(False) @@ -326,10 +324,8 @@ def vlist_of_dict(x): return a x = paddle.to_tensor([3]) - print(paddle.jit.to_static(vlist_of_dict).code) print(paddle.jit.to_static(vlist_of_dict)(x)) if __name__ == '__main__': - os.environ['ENABLE_FALL_BACK'] = "False" unittest.main() diff --git a/test/dygraph_to_static/test_convert_call.py b/test/dygraph_to_static/test_convert_call.py index 77ca5a88f012b1..723d3f910debdd 100644 --- a/test/dygraph_to_static/test_convert_call.py +++ b/test/dygraph_to_static/test_convert_call.py @@ -77,6 +77,7 @@ def dyfunc_with_staticmethod(x_v): return a.add(x_v, x_v) +@dy2static_unittest class TestRecursiveCall1(unittest.TestCase): def setUp(self): self.input = np.random.random([10, 16]).astype('float32') @@ -168,6 +169,7 @@ def forward(self, inputs): return self.act(out) +@dy2static_unittest class TestRecursiveCall2(unittest.TestCase): def setUp(self): self.input = np.random.random((1, 3, 3, 5)).astype('float32') diff --git a/test/dygraph_to_static/test_convert_call_generator.py b/test/dygraph_to_static/test_convert_call_generator.py index b33a41576498db..dd9d93c907c552 100644 --- a/test/dygraph_to_static/test_convert_call_generator.py +++ b/test/dygraph_to_static/test_convert_call_generator.py @@ -14,7 +14,11 @@ import unittest -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.jit import to_static @@ -32,6 +36,7 @@ def main_func(): print(i) +@dy2static_unittest class TestConvertGenerator(unittest.TestCase): # fallback will ok. @ast_only_test diff --git a/test/dygraph_to_static/test_convert_operators.py b/test/dygraph_to_static/test_convert_operators.py index 420e7d8b1e8871..02d0c09a70857c 100644 --- a/test/dygraph_to_static/test_convert_operators.py +++ b/test/dygraph_to_static/test_convert_operators.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -40,6 +44,7 @@ def forward(self): net.forward = "A string so that convert forward will fail" +@dy2static_unittest class TestConvertCall(unittest.TestCase): # fallback mode will raise a InnerError, it's ok. @ast_only_test @@ -68,6 +73,7 @@ def callable_list(x, y): self.assertEqual(callable_list(1, 2), 3) +@dy2static_unittest class TestConvertShapeCompare(unittest.TestCase): def test_non_variable(self): self.assertEqual( @@ -204,6 +210,7 @@ def forward(self, x): return out +@dy2static_unittest class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_tensor_shape(self): @@ -214,6 +221,7 @@ def test_tensor_shape(self): np.testing.assert_array_equal(out.numpy(), x.numpy()) +@dy2static_unittest class TestIfElseNoValue(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_else_ret_none(self): diff --git a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py index f5d6c833d16c1c..b6e55b8900c1e8 100644 --- a/test/dygraph_to_static/test_cpu_cuda_to_tensor.py +++ b/test/dygraph_to_static/test_cpu_cuda_to_tensor.py @@ -25,6 +25,7 @@ import paddle +@dy2static_unittest class TestCpuCuda(unittest.TestCase): def test_cpu_cuda(self): def func(x): @@ -38,6 +39,7 @@ def func(x): # print(paddle.jit.to_static(func)(x)) +@dy2static_unittest class TestToTensor(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_to_tensor_with_variable_list(self): diff --git a/test/dygraph_to_static/test_cycle_gan.py b/test/dygraph_to_static/test_cycle_gan.py index 3484b27d5fac5e..fb06a52407ec61 100644 --- a/test/dygraph_to_static/test_cycle_gan.py +++ b/test/dygraph_to_static/test_cycle_gan.py @@ -12,16 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License import os import random @@ -36,7 +26,10 @@ # Use GPU:0 to elimate the influence of other tasks. os.environ["CUDA_VISIBLE_DEVICES"] = "1" -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.base.dygraph import to_variable @@ -686,6 +679,7 @@ def train(args, to_static): return np.array(loss_data) +@dy2static_unittest class TestCycleGANModel(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_declarative.py b/test/dygraph_to_static/test_declarative.py index 9d3e1e54b0ebb5..12b098cc10ac56 100644 --- a/test/dygraph_to_static/test_declarative.py +++ b/test/dygraph_to_static/test_declarative.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, + test_and_compare_with_new_ir, +) from test_basic_api_transformation import dyfunc_to_variable import paddle @@ -31,8 +35,6 @@ from paddle.nn import Layer from paddle.static import InputSpec -os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only - class SimpleNet(Layer): def __init__(self): @@ -89,7 +91,7 @@ def func_with_list_dict(self, dl): return z -class TestStaticFunctionInstance(unittest.TestCase): +class TestStaticFunctionInstance(Dy2StTestBase): def test_instance_same_class(self): with base.dygraph.guard(base.CPUPlace()): net_1 = SimpleNet() @@ -106,7 +108,7 @@ def test_instance_same_class(self): self.assertTrue(len(net_2.forward.program_cache) == 0) -class TestInputSpec(unittest.TestCase): +class TestInputSpec(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join(self.temp_dir.name, 'simple_net') @@ -115,6 +117,7 @@ def tearDown(self): self.temp_dir.cleanup() @test_and_compare_with_new_ir(False) + @ast_only_test def test_with_input_spec(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -175,6 +178,7 @@ def test_with_error(self): ) net.add_func(x, y) + @ast_only_test def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): x = to_variable(np.ones([4, 10]).astype('float32')) @@ -210,11 +214,12 @@ def foo_func(a, b, c=1, d=2): return z -class TestDifferentInputSpecCacheProgram(unittest.TestCase): +class TestDifferentInputSpecCacheProgram(Dy2StTestBase): def setUp(self): paddle.jit.enable_to_static(True) @test_and_compare_with_new_ir(False) + @ast_only_test def test_with_different_input(self): with base.dygraph.guard(base.CPUPlace()): x_data = np.ones([16, 10]).astype('float32') @@ -260,6 +265,7 @@ def test_with_different_input(self): recent_program = foo.program_cache.last() self.assertTrue(first_program == recent_program) + @ast_only_test def test_get_concrete_program(self): foo = to_static(foo_func) @@ -301,6 +307,7 @@ def test_get_concrete_program(self): ) @test_and_compare_with_new_ir(False) + @ast_only_test def test_concrete_program(self): with base.dygraph.guard(base.CPUPlace()): # usage 1 @@ -324,7 +331,7 @@ def test_concrete_program(self): foo_3.concrete_program # noqa: B018 -class TestInputDefaultName(unittest.TestCase): +class TestInputDefaultName(Dy2StTestBase): def setUp(self): paddle.disable_static() self.net = SimpleNet() @@ -348,7 +355,8 @@ def test_nest_input(self): self.assert_default_name('func_with_list_dict', ['dl_0', 'x', 'y']) -class TestDeclarativeAPI(unittest.TestCase): +class TestDeclarativeAPI(Dy2StTestBase): + @ast_only_test def test_error(self): func = to_static(dyfunc_to_variable) @@ -366,19 +374,21 @@ def test_error(self): func(np.ones(5).astype("int32")) -class TestDecorateModelDirectly(unittest.TestCase): +class TestDecorateModelDirectly(Dy2StTestBase): def setUp(self): paddle.disable_static() paddle.jit.enable_to_static(True) self.x = to_variable(np.ones([4, 10]).astype('float32')) @test_and_compare_with_new_ir(False) + @ast_only_test def test_fake_input(self): net = SimpleNet() net = to_static(net) y = net(self.x) self.assertTrue(len(net.forward.program_cache) == 1) + @ast_only_test def test_input_spec(self): net = SimpleNet() net = to_static(net, input_spec=[InputSpec([None, 8, 10])]) @@ -393,7 +403,7 @@ def test_input_spec(self): self.assertListEqual(list(input_shape), [-1, 16, 10]) -class TestErrorWithInitFromStaticMode(unittest.TestCase): +class TestErrorWithInitFromStaticMode(Dy2StTestBase): def test_raise_error(self): # disable imperative paddle.enable_static() @@ -435,7 +445,7 @@ def func(self): return x -class TestCallNonForwardFunc(unittest.TestCase): +class TestCallNonForwardFunc(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_call_non_forward(self): paddle.disable_static() @@ -468,7 +478,7 @@ def forward(self): return self.b -class TestSetBuffers(unittest.TestCase): +class TestSetBuffers(Dy2StTestBase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.model_path = os.path.join(self.temp_dir.name, 'SetBuffersNet1') @@ -485,6 +495,7 @@ def test_set_buffers1(self): paddle.jit.save(net, self.model_path) paddle.enable_static() + @ast_only_test def test_set_buffers2(self): paddle.disable_static() net = SetBuffersNet2() @@ -498,7 +509,7 @@ def func(self, x): return x + 1 -class TestClassNoInheritLayer(unittest.TestCase): +class TestClassNoInheritLayer(Dy2StTestBase): def test_to_static(self): paddle.disable_static() net = ClassNoInheritLayer() diff --git a/test/dygraph_to_static/test_decorator_transform.py b/test/dygraph_to_static/test_decorator_transform.py index d0ddffdd40cbe7..4f4096d607dc8a 100644 --- a/test/dygraph_to_static/test_decorator_transform.py +++ b/test/dygraph_to_static/test_decorator_transform.py @@ -19,9 +19,9 @@ import decos import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, ast_only_test, - dy2static_unittest, test_and_compare_with_new_ir, ) @@ -185,8 +185,7 @@ def deco_with_paddle_api(): return fun10() -@dy2static_unittest -class TestDecoratorTransform(unittest.TestCase): +class TestDecoratorTransform(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_deco_transform(self): outs = paddle.jit.to_static(forward)() diff --git a/test/dygraph_to_static/test_deepcopy.py b/test/dygraph_to_static/test_deepcopy.py index 0959d74dbc1fbf..82ffeaf9f2290c 100644 --- a/test/dygraph_to_static/test_deepcopy.py +++ b/test/dygraph_to_static/test_deepcopy.py @@ -16,14 +16,18 @@ from copy import deepcopy import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) from test_rollback import Net, foo import paddle from paddle.jit.dy2static.program_translator import StaticFunction -class TestDeepCopy(unittest.TestCase): +# @dy2static_unittest +class TestDeepCopy(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_net(self): net = Net() diff --git a/test/dygraph_to_static/test_dict.py b/test/dygraph_to_static/test_dict.py index 80180b522cf540..99364c1343a7d6 100644 --- a/test/dygraph_to_static/test_dict.py +++ b/test/dygraph_to_static/test_dict.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -116,6 +119,7 @@ def update_cache(cache): return cache +@dy2static_unittest class TestNetWithDict(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -169,6 +173,7 @@ def test_dic_pop_2(x): return out +@dy2static_unittest class TestDictPop(unittest.TestCase): def setUp(self): self.input = np.random.random(3).astype('int32') @@ -249,6 +254,7 @@ def test_ast_to_func(self): ) +@dy2static_unittest class TestDictCmpInFor(unittest.TestCase): def test_with_for(self): def func(): diff --git a/test/dygraph_to_static/test_drop_path.py b/test/dygraph_to_static/test_drop_path.py index a9ea20be04c383..aad752007ceb0c 100644 --- a/test/dygraph_to_static/test_drop_path.py +++ b/test/dygraph_to_static/test_drop_path.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -36,6 +39,7 @@ def forward(self, x): return drop_path(x, self.training) +@dy2static_unittest class TestTrainEval(unittest.TestCase): def setUp(self): self.model = DropPath() diff --git a/test/dygraph_to_static/test_duplicate_output.py b/test/dygraph_to_static/test_duplicate_output.py index 7e4220899d5eff..add3a7262446ae 100644 --- a/test/dygraph_to_static/test_duplicate_output.py +++ b/test/dygraph_to_static/test_duplicate_output.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -38,6 +41,7 @@ def forward(self, x): return x, x +@dy2static_unittest class TestDuplicateOutput(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` diff --git a/test/dygraph_to_static/test_error.py b/test/dygraph_to_static/test_error.py index 8c6f74d75c4e0b..c12dc3887f23d7 100644 --- a/test/dygraph_to_static/test_error.py +++ b/test/dygraph_to_static/test_error.py @@ -23,8 +23,6 @@ from paddle.jit.dy2static import error from paddle.jit.dy2static.origin_info import unwrap -os.environ['ENABLE_FALL_BACK'] = "False" # NOTE: ast only - def inner_func(): paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int") @@ -257,9 +255,9 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 37, in func_error_in_compile_time', + f'File "{self.filepath}", line 35, in func_error_in_compile_time', 'inner_func()', - f'File "{self.filepath}", line 30, in inner_func', + f'File "{self.filepath}", line 28, in inner_func', 'def inner_func():', 'paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', '<--- HERE', @@ -286,7 +284,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 48, in func_error_in_compile_time_2', + f'File "{self.filepath}", line 46, in func_error_in_compile_time_2', 'def func_error_in_compile_time_2(x):', 'x = base.dygraph.to_variable(x)', 'x = paddle.reshape(x, shape=[1, 2])', @@ -310,7 +308,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 93, in forward', + f'File "{self.filepath}", line 91, in forward', '@paddle.jit.to_static', 'def forward(self):', 'self.test_func()', @@ -334,7 +332,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 56, in func_error_in_runtime', + f'File "{self.filepath}", line 54, in func_error_in_runtime', 'x = base.dygraph.to_variable(x)', 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', 'x = paddle.reshape(x, shape=[1, two])', @@ -349,7 +347,7 @@ def set_func(self): def set_message(self): self.expected_message = [ - 'File "{}", line 108, in func_error_in_runtime_with_empty_line'.format( + 'File "{}", line 106, in func_error_in_runtime_with_empty_line'.format( self.filepath ), 'two = paddle.tensor.fill_constant(shape=[1], value=2, dtype="int32")', @@ -372,7 +370,7 @@ def set_exception_type(self): def set_message(self): self.expected_message = [ - f'File "{self.filepath}", line 82, in forward', + f'File "{self.filepath}", line 80, in forward', 'def forward(self, x):', 'y = self._linear(x)', 'z = paddle.tensor.fill_constant(shape=[1, 2], value=9, dtype="int")', diff --git a/test/dygraph_to_static/test_fallback.py b/test/dygraph_to_static/test_fallback.py index b641f8b22233ad..58394feda2a680 100644 --- a/test/dygraph_to_static/test_fallback.py +++ b/test/dygraph_to_static/test_fallback.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle @@ -51,6 +51,7 @@ def forward(self, x): return unsupport_func(x - 1) +@dy2static_unittest class TestFallback(unittest.TestCase): def setUp(self): self.x = paddle.to_tensor([2]).astype('int') diff --git a/test/dygraph_to_static/test_fetch_feed.py b/test/dygraph_to_static/test_fetch_feed.py index 0834f2ec4a315e..b44578fad2c9e3 100644 --- a/test/dygraph_to_static/test_fetch_feed.py +++ b/test/dygraph_to_static/test_fetch_feed.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -62,6 +65,7 @@ def forward(self, x): return pre, loss +@dy2static_unittest class TestPool2D(unittest.TestCase): def setUp(self): self.dygraph_class = Pool2D diff --git a/test/dygraph_to_static/test_for_enumerate.py b/test/dygraph_to_static/test_for_enumerate.py index bbb64e8756ea33..dc9505a5cf6fcc 100644 --- a/test/dygraph_to_static/test_for_enumerate.py +++ b/test/dygraph_to_static/test_for_enumerate.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -353,6 +354,7 @@ def tensor_array_slice_in_enumerate(): return feat_n2 +@dy2static_unittest class TestTransformBase(unittest.TestCase): def setUp(self): self.place = ( @@ -556,6 +558,7 @@ def test_transformed_result_compare(self): self.transformed_result_compare() +@dy2static_unittest class TestForZip(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_full_name_usage.py b/test/dygraph_to_static/test_full_name_usage.py index 0332480891e166..39a80acb566ea2 100644 --- a/test/dygraph_to_static/test_full_name_usage.py +++ b/test/dygraph_to_static/test_full_name_usage.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle import base @@ -58,6 +58,7 @@ def double_decorated_func2(self, x): return jit_decorated_func(x) +@dy2static_unittest class TestFullNameDecorator(unittest.TestCase): @ast_only_test def test_run_success(self): diff --git a/test/dygraph_to_static/test_grad.py b/test/dygraph_to_static/test_grad.py index e542d87efc90ce..ceca09e7895486 100644 --- a/test/dygraph_to_static/test_grad.py +++ b/test/dygraph_to_static/test_grad.py @@ -65,6 +65,7 @@ def forward(self, x): return out +@dy2static_unittest class TestGrad(unittest.TestCase): def setUp(self): self.func = paddle.jit.to_static(GradLayer()) diff --git a/test/dygraph_to_static/test_gradient_aggregation.py b/test/dygraph_to_static/test_gradient_aggregation.py index ab7effba5b16c6..4172fb87197df7 100644 --- a/test/dygraph_to_static/test_gradient_aggregation.py +++ b/test/dygraph_to_static/test_gradient_aggregation.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -37,6 +40,7 @@ def forward(self, x): # return [out2, out1] # 梯度正常 +@dy2static_unittest class TestGradientAggregationInDy2Static(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_to_static(self): diff --git a/test/dygraph_to_static/test_grid_generator.py b/test/dygraph_to_static/test_grid_generator.py index ea1eafb5c1fa9f..7c1a9189366e0e 100644 --- a/test/dygraph_to_static/test_grid_generator.py +++ b/test/dygraph_to_static/test_grid_generator.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) import paddle from paddle import ParamAttr, nn @@ -126,7 +129,7 @@ def get_expand_tensor(self, batch_C_prime): return batch_C_ex_part_tensor -class TestGridGenerator(unittest.TestCase): +class TestGridGenerator(Dy2StTestBase): def setUp(self): self.x = paddle.uniform(shape=[1, 20, 2], dtype='float32') diff --git a/test/dygraph_to_static/test_ifelse.py b/test/dygraph_to_static/test_ifelse.py index 6e2dc6f8ffe6df..12db665b8c822a 100644 --- a/test/dygraph_to_static/test_ifelse.py +++ b/test/dygraph_to_static/test_ifelse.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, dy2static_unittest +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from ifelse_simple_func import ( NetWithControlFlowIf, add_fn, @@ -63,6 +67,7 @@ def setUp(self): self.error = "Your if/else have different number of return value." @ast_only_test + @test_and_compare_with_new_ir() def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -72,6 +77,7 @@ def test_error(self): paddle.jit.enable_to_static(False) +@dy2static_unittest class TestDygraphIfElse(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -94,6 +100,7 @@ def _run_dygraph(self, to_static=False): ret = self.dyfunc(x_v) return ret.numpy() + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -122,11 +129,28 @@ def setUp(self): self.dyfunc = dyfunc_with_if_else_with_list_generator -class TestDygraphNestedIfElse(TestDygraphIfElse): +@dy2static_unittest +class TestDygraphNestedIfElse(unittest.TestCase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = nested_if_else + def _run_static(self): + return self._run_dygraph(to_static=True) + + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # TODO(zhangbo): open pir test (sub block cannot find var in parent block) + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class TestDygraphNestedIfElse2(TestDygraphIfElse): def setUp(self): @@ -232,12 +256,30 @@ def setUp(self): self.dyfunc = if_with_class_var -class TestDygraphIfTensor(TestDygraphIfElse): +@dy2static_unittest +class TestDygraphIfTensor(unittest.TestCase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.dyfunc = if_tensor_case + def _run_static(self): + return self._run_dygraph(to_static=True) + def _run_dygraph(self, to_static=False): + with base.dygraph.guard(place): + x_v = paddle.to_tensor(self.x) + if to_static: + ret = paddle.jit.to_static(self.dyfunc)(x_v) + else: + ret = self.dyfunc(x_v) + return ret.numpy() + + # TODO(zhangbo): open pir test (abnormal insertion of fill constant op after conditional block op) + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + + +@dy2static_unittest class TestDygraphIfElseNet(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` @@ -263,6 +305,7 @@ def _run(self, to_static=False): ret = net(x_v) return ret.numpy() + # TODO(zhangbo): open pir test (sub block cannot find var in parent block) def test_ast_to_func(self): self.assertTrue((self._run_dygraph() == self._run_static()).all()) @@ -316,6 +359,10 @@ def setUp(self): self.x = np.random.random([10, 16]).astype('float32') self.Net = NetWithExternalFunc + @test_and_compare_with_new_ir() + def test_ast_to_func(self): + self.assertTrue((self._run_dygraph() == self._run_static()).all()) + class DiffModeNet1(paddle.nn.Layer): def __init__(self, mode): @@ -350,6 +397,7 @@ def forward(self, x, y): raise ValueError('Illegal mode') +@dy2static_unittest class TestDiffModeNet(unittest.TestCase): """ TestCase for the net with different modes @@ -370,6 +418,7 @@ def _run(self, mode, to_static): ret = net(self.x, self.y) return ret.numpy() + @test_and_compare_with_new_ir() def test_train_mode(self): self.assertTrue( ( @@ -378,6 +427,7 @@ def test_train_mode(self): ).all() ) + @test_and_compare_with_new_ir() def test_infer_mode(self): self.assertTrue( ( @@ -392,7 +442,9 @@ def init_net(self): self.Net = DiffModeNet2 +@dy2static_unittest class TestNewVarCreateInOneBranch(unittest.TestCase): + @test_and_compare_with_new_ir() def test_var_used_in_another_for(self): def case_func(training): # targets and targets_list is dynamically defined by training @@ -430,6 +482,7 @@ def get_dy2stat_out(self): return out @ast_only_test + @test_and_compare_with_new_ir() def test_ast_to_func(self): self.setUp() self.assertIsInstance(self.out[0], (paddle.Tensor, core.eager.Tensor)) @@ -450,6 +503,7 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int3) self.out = self.get_dy2stat_out() + # TODO(zhangbo): open pir test (abnormal insertion of fill constant op after conditional block op) @ast_only_test def test_ast_to_func(self): self.setUp() @@ -463,6 +517,7 @@ def setUp(self): self.dyfunc = paddle.jit.to_static(dyfunc_ifelse_ret_int4) @ast_only_test + @test_and_compare_with_new_ir() def test_ast_to_func(self): paddle.jit.enable_to_static(True) with self.assertRaises(Dygraph2StaticException): @@ -497,7 +552,9 @@ def forward(self, a, b, c): return b +@dy2static_unittest class TestDy2StIfElseBackward(unittest.TestCase): + # TODO(zhangbo): open pir test (IfOp grad execution not yet supported) def test_run_backward(self): a = paddle.randn((4, 3), dtype='float32') a.stop_gradient = False diff --git a/test/dygraph_to_static/test_isinstance.py b/test/dygraph_to_static/test_isinstance.py index e3557dc32658f9..7dfd05989dabe8 100644 --- a/test/dygraph_to_static/test_isinstance.py +++ b/test/dygraph_to_static/test_isinstance.py @@ -26,7 +26,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import nn @@ -85,6 +88,7 @@ def train(model, to_static): return out.numpy() +@dy2static_unittest class TestIsinstance(unittest.TestCase): def test_isinstance_simple_return_layer(self): model = IsInstanceLayer(SimpleReturnLayer()) diff --git a/test/dygraph_to_static/test_jit_property_save.py b/test/dygraph_to_static/test_jit_property_save.py index f25c128e265d7a..965168dedc6ea0 100644 --- a/test/dygraph_to_static/test_jit_property_save.py +++ b/test/dygraph_to_static/test_jit_property_save.py @@ -14,9 +14,12 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle +@dy2static_unittest class TestPropertySave(unittest.TestCase): """test jit property save""" diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 59841ed431f086..219e6a6c9de749 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -16,11 +16,13 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F +@dy2static_unittest class TestSetItemBase(unittest.TestCase): def setUp(self) -> None: pass diff --git a/test/dygraph_to_static/test_lac.py b/test/dygraph_to_static/test_lac.py index 522eb81cf5a7ae..461b03fe7a5edc 100644 --- a/test/dygraph_to_static/test_lac.py +++ b/test/dygraph_to_static/test_lac.py @@ -22,6 +22,8 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "2" +from dygraph_to_static_util import dy2static_unittest + import paddle from paddle import _legacy_C_ops, base from paddle.base.dygraph import to_variable @@ -513,6 +515,7 @@ def create_dataloader(reader, place): return data_loader +@dy2static_unittest class TestLACModel(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_lambda.py b/test/dygraph_to_static/test_lambda.py index c1ff57147564c5..add572cb6dfcff 100644 --- a/test/dygraph_to_static/test_lambda.py +++ b/test/dygraph_to_static/test_lambda.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F @@ -79,6 +80,7 @@ def call_lambda_with_ifExpr2(x): return out +@dy2static_unittest class TestLambda(unittest.TestCase): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') diff --git a/test/dygraph_to_static/test_layer_hook.py b/test/dygraph_to_static/test_layer_hook.py index bf679cf8dcc2e4..d19b9ea9abfc94 100644 --- a/test/dygraph_to_static/test_layer_hook.py +++ b/test/dygraph_to_static/test_layer_hook.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -56,6 +59,7 @@ def forward(self, x): return out +@dy2static_unittest class TestNestLayerHook(unittest.TestCase): def setUp(self): paddle.seed(2022) diff --git a/test/dygraph_to_static/test_len.py b/test/dygraph_to_static/test_len.py index e2cee7c4dc8b44..340ba86ff50c2f 100644 --- a/test/dygraph_to_static/test_len.py +++ b/test/dygraph_to_static/test_len.py @@ -15,6 +15,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -42,6 +43,7 @@ def len_with_lod_tensor_array(x): return arr_len +@dy2static_unittest class TestLen(unittest.TestCase): def setUp(self): self.place = ( @@ -113,6 +115,7 @@ def len_with_selected_rows(place): return result +@dy2static_unittest class TestLenWithSelectedRows(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_list.py b/test/dygraph_to_static/test_list.py index 9ad646de8818c9..51b28ce3fe38a7 100644 --- a/test/dygraph_to_static/test_list.py +++ b/test/dygraph_to_static/test_list.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -207,6 +208,7 @@ def test_list_pop_in_while_loop(x, iter_num): return a[0], b[2] +@dy2static_unittest class TestListWithoutControlFlow(unittest.TestCase): def setUp(self): self.place = ( @@ -354,6 +356,7 @@ def forward(self, x, index, *args): return z +@dy2static_unittest class TestListWithCondGradInferVarType(unittest.TestCase): def test_to_static(self): net = ListWithCondNet() diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 95e06a51f3c692..81a45fb91cc4ef 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + test_and_compare_with_new_ir, +) import paddle @@ -41,7 +44,7 @@ def forward(self, x): return t -class TestFallback(unittest.TestCase): +class TestFallback(Dy2StTestBase): def setUp(self): self.x = paddle.to_tensor(1.0).astype('int') @@ -54,7 +57,7 @@ def test_name_load(self): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) -class TestLoad2(unittest.TestCase): +class TestLoad2(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_name_load_nograd(self): @paddle.no_grad() diff --git a/test/dygraph_to_static/test_logical.py b/test/dygraph_to_static/test_logical.py index 9e0f1d12bd9b48..a05f91b7c04932 100644 --- a/test/dygraph_to_static/test_logical.py +++ b/test/dygraph_to_static/test_logical.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import base @@ -167,6 +168,7 @@ def test_shape_not_equal(x): return paddle.ones([1, 2, 3]) +@dy2static_unittest class TestLogicalBase(unittest.TestCase): def setUp(self): self.input = np.array([3]).astype('int32') @@ -262,6 +264,7 @@ def _set_test_func(self): self.dygraph_func = test_shape_not_equal +@dy2static_unittest class TestCmpopNodeToStr(unittest.TestCase): def test_exception(self): with self.assertRaises(KeyError): diff --git a/test/dygraph_to_static/test_loop.py b/test/dygraph_to_static/test_loop.py index 77f568e2c5eec9..422508d6cd97e8 100644 --- a/test/dygraph_to_static/test_loop.py +++ b/test/dygraph_to_static/test_loop.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle import paddle.nn.functional as F @@ -229,6 +230,7 @@ def for_loop_dufunc_with_listcomp(array): return res +@dy2static_unittest class TestNameVisitor(unittest.TestCase): def setUp(self): self.loop_funcs = [ @@ -299,6 +301,7 @@ def test_nested_loop_vars(self): i += 1 +@dy2static_unittest class TestTransformWhileLoop(unittest.TestCase): def setUp(self): self.place = ( @@ -378,6 +381,7 @@ def _init_dyfunc(self): self.dyfunc = loop_var_contains_property +@dy2static_unittest class TestTransformForLoop(unittest.TestCase): def setUp(self): self.place = ( @@ -460,6 +464,7 @@ def forward(self, x): return out +@dy2static_unittest class TestForLoopMeetDict(unittest.TestCase): def test_start(self): net = Net() diff --git a/test/dygraph_to_static/test_mnist.py b/test/dygraph_to_static/test_mnist.py index 9641a9225cee7b..984176a83afe01 100644 --- a/test/dygraph_to_static/test_mnist.py +++ b/test/dygraph_to_static/test_mnist.py @@ -18,7 +18,11 @@ from time import time import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from predictor_utils import PredictorTools import paddle @@ -126,6 +130,7 @@ def inference(self, inputs): return x +@dy2static_unittest class TestMNIST(unittest.TestCase): def setUp(self): self.epoch_num = 1 diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py index 5536a14e695c48..cca77999d5e7d9 100644 --- a/test/dygraph_to_static/test_mobile_net.py +++ b/test/dygraph_to_static/test_mobile_net.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -656,6 +656,7 @@ def predict_analysis_inference(args, data): return out +@dy2static_unittest class TestMobileNet(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_multi_forward.py b/test/dygraph_to_static/test_multi_forward.py index 039db089b5c86b..2cf8e592f3fa0f 100644 --- a/test/dygraph_to_static/test_multi_forward.py +++ b/test/dygraph_to_static/test_multi_forward.py @@ -14,7 +14,10 @@ import unittest -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -33,6 +36,7 @@ def forward(self, x): return self.linear(x) +@dy2static_unittest class TestBackward(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_order_0(self): diff --git a/test/dygraph_to_static/test_new_ir_selectedrows.py b/test/dygraph_to_static/test_new_ir_selectedrows.py index 7d87a48fe78585..e403cbd6089a13 100644 --- a/test/dygraph_to_static/test_new_ir_selectedrows.py +++ b/test/dygraph_to_static/test_new_ir_selectedrows.py @@ -15,10 +15,7 @@ import random import unittest -from dygraph_to_static_util import ( - enable_fallback_guard, - test_and_compare_with_new_ir, -) +from dygraph_to_static_util import test_and_compare_with_new_ir import paddle from paddle.jit.api import to_static @@ -104,5 +101,4 @@ def test_dygraph_static_same_loss(self): if __name__ == '__main__': - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_op_attr.py b/test/dygraph_to_static/test_op_attr.py index 17394df88dd071..6aaf1cdbf21385 100644 --- a/test/dygraph_to_static/test_op_attr.py +++ b/test/dygraph_to_static/test_op_attr.py @@ -14,7 +14,7 @@ import unittest -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.static import InputSpec @@ -52,6 +52,7 @@ def with_cond(self, x): return out +@dy2static_unittest class CheckOpAttr(unittest.TestCase): def setUp(self): self.in_num = 16 diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index e2925d4fa1a4bd..be38650b750c21 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -16,6 +16,8 @@ import sys import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.api import to_static from paddle.jit.dy2static import DygraphToStaticAst from paddle.jit.dy2static.origin_info import ( @@ -54,6 +56,7 @@ def decorated_func2(x): return x +@dy2static_unittest class TestOriginInfo(unittest.TestCase): def setUp(self): self.set_test_func() diff --git a/test/dygraph_to_static/test_param_guard.py b/test/dygraph_to_static/test_param_guard.py index b8edaf50dfceda..c6787db58fc890 100644 --- a/test/dygraph_to_static/test_param_guard.py +++ b/test/dygraph_to_static/test_param_guard.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle.jit import to_static @@ -50,6 +53,7 @@ def forward(self, x): return out +@dy2static_unittest class TestParameterList(unittest.TestCase): def setUp(self): self.seed = 2021 @@ -102,6 +106,7 @@ def forward(self, x): return out +@dy2static_unittest class TestRawParameterList(unittest.TestCase): def setUp(self): self.seed = 2021 diff --git a/test/dygraph_to_static/test_params_no_grad.py b/test/dygraph_to_static/test_params_no_grad.py index f7bf87888f49cd..3b3f3949fad57c 100644 --- a/test/dygraph_to_static/test_params_no_grad.py +++ b/test/dygraph_to_static/test_params_no_grad.py @@ -14,6 +14,8 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle import paddle.distributed as dist from paddle import nn @@ -52,6 +54,7 @@ def train(): print(loss) +@dy2static_unittest class TestParamsNoGrad(unittest.TestCase): def test_two_card(self): if ( diff --git a/test/dygraph_to_static/test_partial_program.py b/test/dygraph_to_static/test_partial_program.py index db4a7c21e40100..7d1c2f0d1ae083 100644 --- a/test/dygraph_to_static/test_partial_program.py +++ b/test/dygraph_to_static/test_partial_program.py @@ -15,9 +15,9 @@ import unittest import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, ast_only_test, - dy2static_unittest, test_and_compare_with_new_ir, ) from test_fetch_feed import Linear @@ -57,8 +57,7 @@ def fake_data(shape): return base.dygraph.to_variable(x_data) -@dy2static_unittest -class TestWithNestedInput(unittest.TestCase): +class TestWithNestedInput(Dy2StTestBase): def setUp(self): self.x = None self.y = None @@ -95,8 +94,7 @@ def test_nest(self): np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05) -@dy2static_unittest -class TestWithNestedOutput(unittest.TestCase): +class TestWithNestedOutput(Dy2StTestBase): def setUp(self): self.x = None self.y = None @@ -133,8 +131,7 @@ def test_nest(self): self.assertTrue(dy_var, st_var) -@dy2static_unittest -class TestWithTrainAndEval(unittest.TestCase): +class TestWithTrainAndEval(Dy2StTestBase): @ast_only_test @test_and_compare_with_new_ir(False) def test_switch_eval_and_train(self): @@ -167,8 +164,7 @@ def test_switch_eval_and_train(self): ) -@dy2static_unittest -class TestWithNoGrad(unittest.TestCase): +class TestWithNoGrad(Dy2StTestBase): @ast_only_test @test_and_compare_with_new_ir(False) def test_with_no_grad(self): @@ -204,8 +200,7 @@ def forward(self, x): return x1 -@dy2static_unittest -class TestPruneUnusedParamInProgram(unittest.TestCase): +class TestPruneUnusedParamInProgram(Dy2StTestBase): @test_and_compare_with_new_ir(False) def test_prune(self): input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32") diff --git a/test/dygraph_to_static/test_partial_program_hook.py b/test/dygraph_to_static/test_partial_program_hook.py index cb177862692d30..c10194f6187adf 100644 --- a/test/dygraph_to_static/test_partial_program_hook.py +++ b/test/dygraph_to_static/test_partial_program_hook.py @@ -15,11 +15,14 @@ import os import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle from paddle.base import core from paddle.jit.dy2static import partial_program, program_translator +@dy2static_unittest class TestPartiaProgramLayerHook(unittest.TestCase): def setUp(self): os.environ["ENABLE_FALL_BACK"] = "False" @@ -35,6 +38,7 @@ def test_after_infer(self): self.assertIsNone(self._hook.after_infer(None)) +@dy2static_unittest class TestPrimHook(unittest.TestCase): def setUp(self): os.environ["ENABLE_FALL_BACK"] = "False" diff --git a/test/dygraph_to_static/test_place.py b/test/dygraph_to_static/test_place.py index 2ed904a0b54902..f1cb7e80589a31 100644 --- a/test/dygraph_to_static/test_place.py +++ b/test/dygraph_to_static/test_place.py @@ -14,9 +14,12 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + import paddle +@dy2static_unittest class TestPlace(unittest.TestCase): def test_place(self): paddle.enable_static() diff --git a/test/dygraph_to_static/test_print.py b/test/dygraph_to_static/test_print.py index d7fe1f5a882c07..251bca776e700b 100644 --- a/test/dygraph_to_static/test_print.py +++ b/test/dygraph_to_static/test_print.py @@ -15,7 +15,10 @@ import unittest import numpy -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -84,6 +87,7 @@ def dyfunc_print_with_kwargs(x): print("Tensor", x_t, end='\n\n', sep=': ') +@dy2static_unittest class TestPrintBase(unittest.TestCase): def setUp(self): self.input = numpy.ones(5).astype("int32") diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index 25cf316dd7e91c..d2909d07a50b2f 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -18,7 +18,7 @@ import astor import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2, @@ -212,6 +212,7 @@ def forward(self, x): return y +@dy2static_unittest class TestEnableDeclarative(unittest.TestCase): def setUp(self): self.x = np.random.randn(30, 10, 32).astype('float32') @@ -267,6 +268,7 @@ def switch_mode_function(): return True +@dy2static_unittest class TestFunctionTrainEvalMode(unittest.TestCase): @ast_only_test def test_switch_mode(self): @@ -297,6 +299,7 @@ def test_raise_error(self): net.foo.train() +@dy2static_unittest class TestIfElseEarlyReturn(unittest.TestCase): def test_ifelse_early_return1(self): answer = np.zeros([2, 2]) + 1 @@ -311,6 +314,7 @@ def test_ifelse_early_return2(self): np.testing.assert_allclose(answer, out[0].numpy(), rtol=1e-05) +@dy2static_unittest class TestRemoveCommentInDy2St(unittest.TestCase): def func_with_comment(self): # Comment1 @@ -352,6 +356,7 @@ def func1(x): return func1(data) +@dy2static_unittest class TestParameterRecorder(unittest.TestCase): def test_recorder(self): """function calls nn.Layer case.""" diff --git a/test/dygraph_to_static/test_ptb_lm.py b/test/dygraph_to_static/test_ptb_lm.py index 2c94d6b343d3a8..76a35d57ac9baf 100644 --- a/test/dygraph_to_static/test_ptb_lm.py +++ b/test/dygraph_to_static/test_ptb_lm.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -321,6 +324,7 @@ def train_static(place): return train(place) +@dy2static_unittest class TestPtb(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_ptb_lm_v2.py b/test/dygraph_to_static/test_ptb_lm_v2.py index 3694d503965361..92d4d43d9d4ea2 100644 --- a/test/dygraph_to_static/test_ptb_lm_v2.py +++ b/test/dygraph_to_static/test_ptb_lm_v2.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle @@ -322,6 +323,7 @@ def train_static(place): return train(place) +@dy2static_unittest class TestPtb(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_pylayer.py b/test/dygraph_to_static/test_pylayer.py index c36bc1a14d5d14..0e083a67b0e94d 100644 --- a/test/dygraph_to_static/test_pylayer.py +++ b/test/dygraph_to_static/test_pylayer.py @@ -26,6 +26,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest from test_jit_save_load import train import paddle @@ -262,6 +263,7 @@ def forward(self, x): return out +@dy2static_unittest class TestPyLayerBase(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -512,6 +514,7 @@ def test_pylayer_net_with_no_grad(self): self._run_and_compare(input1, input2) +@dy2static_unittest class PyLayerTrainHelper(unittest.TestCase): def setUp(self): self.place = "gpu" if paddle.is_compiled_with_cuda() else "cpu" @@ -583,6 +586,7 @@ def test_pylayer_net_no_grad(self): ) +@dy2static_unittest class TestPyLayerJitSaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_reinforcement_learning.py b/test/dygraph_to_static/test_reinforcement_learning.py index 2a792ebcda7330..ffbd0e315229d7 100644 --- a/test/dygraph_to_static/test_reinforcement_learning.py +++ b/test/dygraph_to_static/test_reinforcement_learning.py @@ -18,7 +18,10 @@ import gym import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle import paddle.nn.functional as F @@ -203,6 +206,7 @@ def finish_episode(): return np.array(loss_data) +@dy2static_unittest class TestDeclarative(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index a99999c4e74475..cb57ce234b2639 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -386,6 +386,7 @@ def predict_analysis_inference(self, data): return out +@dy2static_unittest class TestResnet(unittest.TestCase): def setUp(self): self.resnet_helper = ResNetHelper() diff --git a/test/dygraph_to_static/test_resnet_amp.py b/test/dygraph_to_static/test_resnet_amp.py index 60a30db707be47..0255c0c00db3b5 100644 --- a/test/dygraph_to_static/test_resnet_amp.py +++ b/test/dygraph_to_static/test_resnet_amp.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -111,6 +114,7 @@ def train(to_static, build_strategy=None): return total_loss.numpy() +@dy2static_unittest class TestResnet(unittest.TestCase): def train(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_resnet_pure_fp16.py b/test/dygraph_to_static/test_resnet_pure_fp16.py index 1eb6a8ac9b3a5a..771f9033f99d73 100644 --- a/test/dygraph_to_static/test_resnet_pure_fp16.py +++ b/test/dygraph_to_static/test_resnet_pure_fp16.py @@ -16,7 +16,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_resnet import SEED, ResNet, optimizer_setting import paddle @@ -112,6 +115,7 @@ def train(to_static, build_strategy=None): return loss_data +@dy2static_unittest class TestResnet(unittest.TestCase): def train(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py index cf941effd2c288..0f5d804427ca67 100644 --- a/test/dygraph_to_static/test_resnet_v2.py +++ b/test/dygraph_to_static/test_resnet_v2.py @@ -19,7 +19,7 @@ import unittest import numpy as np -from dygraph_to_static_util import test_with_new_ir +from dygraph_to_static_util import dy2static_unittest, test_with_new_ir from predictor_utils import PredictorTools import paddle @@ -242,6 +242,7 @@ def __len__(self): return len(self.img) +@dy2static_unittest class TestResnet(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_return.py b/test/dygraph_to_static/test_return.py index 41c622e9ed03ab..0cd14b94267cd5 100644 --- a/test/dygraph_to_static/test_return.py +++ b/test/dygraph_to_static/test_return.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest from ifelse_simple_func import dyfunc_with_if_else import paddle @@ -264,6 +264,7 @@ def func(): return func() +@dy2static_unittest class TestReturnBase(unittest.TestCase): def setUp(self): self.input = np.ones(1).astype('int32') diff --git a/test/dygraph_to_static/test_rollback.py b/test/dygraph_to_static/test_rollback.py index 0efb2147f20761..7ee3456747b513 100644 --- a/test/dygraph_to_static/test_rollback.py +++ b/test/dygraph_to_static/test_rollback.py @@ -71,6 +71,7 @@ def foo(x, flag=False): return out +@dy2static_unittest class TestRollBackPlainFunction(unittest.TestCase): def setUp(self): paddle.set_device("cpu") diff --git a/test/dygraph_to_static/test_save_inference_model.py b/test/dygraph_to_static/test_save_inference_model.py index c6a01d38e7d869..468541cfde39eb 100644 --- a/test/dygraph_to_static/test_save_inference_model.py +++ b/test/dygraph_to_static/test_save_inference_model.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -73,6 +77,7 @@ def forward(self, x): return loss, out +@dy2static_unittest class TestDyToStaticSaveInferenceModel(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -223,6 +228,7 @@ def load_and_run_inference( return np.array(results[0]) +@dy2static_unittest class TestPartialProgramRaiseError(unittest.TestCase): @ast_only_test @test_and_compare_with_new_ir(False) diff --git a/test/dygraph_to_static/test_save_load.py b/test/dygraph_to_static/test_save_load.py index 1c7b34435d7ac5..92965aea2ccc2d 100644 --- a/test/dygraph_to_static/test_save_load.py +++ b/test/dygraph_to_static/test_save_load.py @@ -17,7 +17,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_fetch_feed import Linear import paddle @@ -55,6 +59,7 @@ def forward_post_hook_for_prim_net(layer, input, output): return output * 2 +@dy2static_unittest class TestDyToStaticSaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_se_resnet.py b/test/dygraph_to_static/test_se_resnet.py index c12990b53659d8..3ef1e62bf1cdab 100644 --- a/test/dygraph_to_static/test_se_resnet.py +++ b/test/dygraph_to_static/test_se_resnet.py @@ -20,7 +20,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) from predictor_utils import PredictorTools import paddle @@ -29,6 +33,7 @@ from paddle.jit.api import to_static from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.nn import BatchNorm, Linear +from paddle.static import InputSpec SEED = 2020 np.random.seed(SEED) @@ -346,6 +351,7 @@ def forward(self, inputs, label): return out, avg_loss, acc_top1, acc_top5 +@dy2static_unittest class TestSeResnet(unittest.TestCase): def setUp(self): self.train_reader = paddle.batch( @@ -368,6 +374,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @test_and_compare_with_new_ir(True) def train(self, train_reader, to_static): paddle.jit.enable_to_static(to_static) @@ -450,9 +457,15 @@ def train(self, train_reader, to_static): paddle.jit.save( se_resnext, self.model_save_prefix, - [img, label], output_spec=[pred], - input_names_after_prune=[img.name], + input_names_after_prune=['x'], + input_spec=[ + InputSpec( + shape=[None, 3, 224, 224], name='x' + ), + InputSpec(shape=[None, 1], name='y'), + ], + clip_extra=False, ) else: paddle.save( @@ -483,6 +496,7 @@ def predict_dygraph(self, data): return pred_res.numpy() + @test_and_compare_with_new_ir(True) def predict_static(self, data): paddle.enable_static() exe = base.Executor(place) diff --git a/test/dygraph_to_static/test_sentiment.py b/test/dygraph_to_static/test_sentiment.py index 22bb980cd437f9..60d3678a5a72b0 100644 --- a/test/dygraph_to_static/test_sentiment.py +++ b/test/dygraph_to_static/test_sentiment.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from test_lac import DynamicGRU import paddle @@ -369,6 +372,7 @@ def train(args, to_static): return loss_data +@dy2static_unittest class TestSentiment(unittest.TestCase): def setUp(self): self.args = Args() diff --git a/test/dygraph_to_static/test_seq2seq.py b/test/dygraph_to_static/test_seq2seq.py index 85de170c3f06c6..b97752d4c57cbf 100644 --- a/test/dygraph_to_static/test_seq2seq.py +++ b/test/dygraph_to_static/test_seq2seq.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest from seq2seq_dygraph_model import AttentionModel, BaseModel from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter @@ -174,6 +175,7 @@ def infer(args, attn_model=False): return outputs.numpy() +@dy2static_unittest class TestSeq2seq(unittest.TestCase): def setUp(self): self.args = Seq2SeqModelHyperParams diff --git a/test/dygraph_to_static/test_set_dynamic_shape.py b/test/dygraph_to_static/test_set_dynamic_shape.py new file mode 100644 index 00000000000000..9ad832aac19ad1 --- /dev/null +++ b/test/dygraph_to_static/test_set_dynamic_shape.py @@ -0,0 +1,39 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle + + +class TestSetDynamicShape(unittest.TestCase): + def test_start(self): + def dygraph_func(loop_number): + mask = paddle.randn([2, 2]) + paddle.jit.dy2static.utils_helper.set_dynamic_shape(mask, [-1, 2]) + n = paddle.randn([1, 2]) + for i in range(loop_number): + mask = paddle.concat([mask, n], axis=0) + if mask.shape[0] == 5: + break + return mask + + loop_num = paddle.to_tensor(10) + expected_shape = dygraph_func(loop_num).shape + actual_shape = paddle.jit.to_static(dygraph_func)(loop_num).shape + self.assertEqual(expected_shape, actual_shape) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/dygraph_to_static/test_simnet.py b/test/dygraph_to_static/test_simnet.py index 7d6cad6d033819..90dce27f87eef2 100644 --- a/test/dygraph_to_static/test_simnet.py +++ b/test/dygraph_to_static/test_simnet.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from simnet_dygraph_model import BOW, HingeLoss import paddle @@ -176,8 +179,9 @@ def train(conf_dict, to_static): return losses +@dy2static_unittest class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(True) + @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): if base.is_compiled_with_cuda(): base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_simnet_v2.py b/test/dygraph_to_static/test_simnet_v2.py index a54cfe14dcbf83..16fccfd731be0b 100644 --- a/test/dygraph_to_static/test_simnet_v2.py +++ b/test/dygraph_to_static/test_simnet_v2.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from simnet_dygraph_model_v2 import BOW, HingeLoss import paddle @@ -176,8 +179,9 @@ def train(conf_dict, to_static): return losses +@dy2static_unittest class TestSimnet(unittest.TestCase): - @test_and_compare_with_new_ir(True) + @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): if paddle.is_compiled_with_cuda(): paddle.base.set_flags({"FLAGS_cudnn_deterministic": True}) diff --git a/test/dygraph_to_static/test_slice.py b/test/dygraph_to_static/test_slice.py index e66080a2c687fa..3bd4c5f8a2c837 100644 --- a/test/dygraph_to_static/test_slice.py +++ b/test/dygraph_to_static/test_slice.py @@ -17,7 +17,7 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle.static import InputSpec @@ -108,6 +108,7 @@ def forward(self, x): return x +@dy2static_unittest class TestSliceWithoutControlFlow(unittest.TestCase): def setUp(self): self.init_input() @@ -169,6 +170,7 @@ def init_dygraph_func(self): self.dygraph_func = test_set_value +@dy2static_unittest class TestSetValueWithLayerAndSave(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -189,6 +191,7 @@ def test_set_value_with_save(self): ) +@dy2static_unittest class TestSliceSupplementSpecialCase(unittest.TestCase): # unittest for slice index which abs(step)>0. eg: x[::2] def test_static_slice_step(self): @@ -232,6 +235,7 @@ def func(inps): ) +@dy2static_unittest class TestPaddleStridedSlice(unittest.TestCase): def test_compare_paddle_strided_slice_with_numpy(self): paddle.disable_static() @@ -293,6 +297,7 @@ def slice_zero_shape_tensor(x): return y +@dy2static_unittest class TestSliceZeroShapeTensor(unittest.TestCase): def test_slice(self): paddle.disable_static() diff --git a/test/dygraph_to_static/test_spec_names.py b/test/dygraph_to_static/test_spec_names.py index 86fe69c507631c..72ffdc845134a8 100644 --- a/test/dygraph_to_static/test_spec_names.py +++ b/test/dygraph_to_static/test_spec_names.py @@ -14,8 +14,9 @@ import unittest -from dygraph_to_static_util import ( - enable_fallback_guard, +from dygraph_to_static_utils_new import ( + Dy2StTestBase, + ast_only_test, test_and_compare_with_new_ir, ) @@ -40,7 +41,7 @@ def forward(self, x, y, m, n): return paddle.sum(out) -class TestArgsSpecName(unittest.TestCase): +class TestArgsSpecName(Dy2StTestBase): def read_from_dataset(self): self.x = paddle.randn([4, 2, 8]) self.y = paddle.randn([4, 2, 8]) @@ -48,6 +49,7 @@ def read_from_dataset(self): self.n = paddle.randn([4, 2, 8]) @test_and_compare_with_new_ir(False) + @ast_only_test def test_spec_name_hash(self): net = Net() net = paddle.jit.to_static(net) @@ -90,5 +92,4 @@ def run_test(self, net, inputs, trace_count, mode): if __name__ == '__main__': - with enable_fallback_guard("False"): - unittest.main() + unittest.main() diff --git a/test/dygraph_to_static/test_tensor_hook.py b/test/dygraph_to_static/test_tensor_hook.py index fc53fefc95ae64..06b1b288ad8993 100644 --- a/test/dygraph_to_static/test_tensor_hook.py +++ b/test/dygraph_to_static/test_tensor_hook.py @@ -15,12 +15,14 @@ import unittest import numpy as np +from dygraph_to_static_util import dy2static_unittest import paddle from paddle import nn from paddle.jit import to_static +@dy2static_unittest class TestStaticAnalysis(unittest.TestCase): def test_hook_for_different_parameter(self): def f(x): diff --git a/test/dygraph_to_static/test_tensor_methods.py b/test/dygraph_to_static/test_tensor_methods.py index 6e1ae1a3ffc0e2..65981d65825a4c 100644 --- a/test/dygraph_to_static/test_tensor_methods.py +++ b/test/dygraph_to_static/test_tensor_methods.py @@ -15,7 +15,11 @@ import unittest import numpy as np -from dygraph_to_static_util import ast_only_test, test_and_compare_with_new_ir +from dygraph_to_static_util import ( + ast_only_test, + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle @@ -27,6 +31,7 @@ def tensor_clone(x): return y +@dy2static_unittest class TestTensorClone(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -48,6 +53,7 @@ def tensor_numpy(x): return x +@dy2static_unittest class TestTensorDygraphOnlyMethodError(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -71,6 +77,7 @@ def tensor_item(x): return y.item() +@dy2static_unittest class TestTensorItem(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -95,6 +102,7 @@ def tensor_size(x): return y +@dy2static_unittest class TestTensorSize(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) @@ -120,6 +128,7 @@ def true_div(x, y): return z +@dy2static_unittest class TestTrueDiv(unittest.TestCase): def _run(self, to_static): paddle.jit.enable_to_static(to_static) diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index ad85daf7b0f78b..d8c13cff351931 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -15,9 +15,9 @@ import unittest import numpy as np -from dygraph_to_static_util import ( +from dygraph_to_static_utils_new import ( + Dy2StTestBase, ast_only_test, - dy2static_unittest, test_and_compare_with_new_ir, ) @@ -235,8 +235,7 @@ def dyfunc_dict_assign_shape(): # 1. Basic tests without control flow -@dy2static_unittest -class TestTensorShapeBasic(unittest.TestCase): +class TestTensorShapeBasic(Dy2StTestBase): def setUp(self): self.input = np.ones(5).astype("int32") self.place = ( @@ -495,7 +494,7 @@ def _set_expected_op_num(self): # 5. Test op num for negative dim -class TestOpNumBasicWithTensorShape(unittest.TestCase): +class TestOpNumBasicWithTensorShape(Dy2StTestBase): def setUp(self): self._set_input_spec() self._set_test_func() @@ -617,7 +616,7 @@ def dyfunc_with_static_convert_var_shape(x): return res -class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase): +class TestFindStatiConvertVarShapeSuffixVar(Dy2StTestBase): @ast_only_test def test(self): x_spec = paddle.static.InputSpec(shape=[None, 10]) diff --git a/test/dygraph_to_static/test_to_tensor.py b/test/dygraph_to_static/test_to_tensor.py index ee33d56187efa6..b211e09254eded 100644 --- a/test/dygraph_to_static/test_to_tensor.py +++ b/test/dygraph_to_static/test_to_tensor.py @@ -96,6 +96,10 @@ def case8(x): return a +def case_to_tensor_default_dtype(): + return paddle.to_tensor(1) + + @dy2static_unittest class TestToTensorReturnVal(unittest.TestCase): def test_to_tensor_badreturn(self): @@ -150,6 +154,13 @@ def test_to_tensor_badreturn(self): self.assertTrue(a.stop_gradient == b.stop_gradient) self.assertTrue(a.place._equals(b.place)) + def test_to_tensor_default_dtype(self): + a = paddle.jit.to_static(case_to_tensor_default_dtype)() + b = case_to_tensor_default_dtype() + self.assertTrue(a.dtype == b.dtype) + self.assertTrue(a.stop_gradient == b.stop_gradient) + self.assertTrue(a.place._equals(b.place)) + def test_to_tensor_err_log(self): paddle.disable_static() x = paddle.to_tensor([3]) @@ -162,6 +173,7 @@ def test_to_tensor_err_log(self): ) +@dy2static_unittest class TestStatic(unittest.TestCase): def test_static(self): paddle.enable_static() diff --git a/test/dygraph_to_static/test_transformer.py b/test/dygraph_to_static/test_transformer.py index 073535371ccde3..29dda3916f3ab9 100644 --- a/test/dygraph_to_static/test_transformer.py +++ b/test/dygraph_to_static/test_transformer.py @@ -20,7 +20,10 @@ import numpy as np import transformer_util as util -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from transformer_dygraph_model import ( CrossEntropyCriterion, Transformer, @@ -527,6 +530,7 @@ def predict_static(args, batch_generator): return seq_ids, seq_scores +@dy2static_unittest class TestTransformer(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/test/dygraph_to_static/test_tsm.py b/test/dygraph_to_static/test_tsm.py index e68406bd4c9ab2..2cef9e7df4dedd 100644 --- a/test/dygraph_to_static/test_tsm.py +++ b/test/dygraph_to_static/test_tsm.py @@ -19,7 +19,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from tsm_config_utils import merge_configs, parse_config, print_configs import paddle @@ -384,6 +387,7 @@ def train(args, fake_data_reader, to_static): return ret +@dy2static_unittest class TestTsm(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): diff --git a/test/dygraph_to_static/test_typehint.py b/test/dygraph_to_static/test_typehint.py index b37a3539e22543..563db1d7a1df04 100644 --- a/test/dygraph_to_static/test_typehint.py +++ b/test/dygraph_to_static/test_typehint.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -33,6 +36,7 @@ def function(x: A) -> A: return 2 * x +@dy2static_unittest class TestTransformWhileLoop(unittest.TestCase): def setUp(self): self.place = ( diff --git a/test/dygraph_to_static/test_unuseful_inputs.py b/test/dygraph_to_static/test_unuseful_inputs.py index 603ffe9eba12dc..8f83f015db4315 100644 --- a/test/dygraph_to_static/test_unuseful_inputs.py +++ b/test/dygraph_to_static/test_unuseful_inputs.py @@ -15,7 +15,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import nn @@ -62,6 +65,7 @@ def forward(self, x): return val +@dy2static_unittest class TestDuplicateOutput(unittest.TestCase): """ TestCase for the transformation from control flow `if/else` diff --git a/test/dygraph_to_static/test_utils.py b/test/dygraph_to_static/test_utils.py index 3361a866feb540..180078c1448295 100644 --- a/test/dygraph_to_static/test_utils.py +++ b/test/dygraph_to_static/test_utils.py @@ -15,9 +15,12 @@ import types import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.dy2static.utils import index_in_list, is_paddle_func +@dy2static_unittest class TestIndexInList(unittest.TestCase): def test_index_in_list(self): list_to_test = [1, 2, 3, 4, 5] @@ -49,6 +52,7 @@ def dyfunc_assign(input): y = n +@dy2static_unittest class TestIsPaddle(unittest.TestCase): def fake_module(self): return types.ModuleType('paddlenlp') diff --git a/test/dygraph_to_static/test_variable_trans_func.py b/test/dygraph_to_static/test_variable_trans_func.py index f2395fa517793d..0ca73fbf9dd755 100644 --- a/test/dygraph_to_static/test_variable_trans_func.py +++ b/test/dygraph_to_static/test_variable_trans_func.py @@ -14,10 +14,13 @@ import unittest +from dygraph_to_static_util import dy2static_unittest + from paddle.jit.dy2static.utils import ast_to_source_code from paddle.jit.dy2static.variable_trans_func import create_fill_constant_node +@dy2static_unittest class TestVariableTransFunc(unittest.TestCase): def test_create_fill_constant_node(self): node = create_fill_constant_node("a", 1.0) diff --git a/test/dygraph_to_static/test_word2vec.py b/test/dygraph_to_static/test_word2vec.py index 85edea2093d82f..0f16f5b2a9d23f 100644 --- a/test/dygraph_to_static/test_word2vec.py +++ b/test/dygraph_to_static/test_word2vec.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) import paddle from paddle import base @@ -318,6 +321,7 @@ def train(to_static): return np.array(ret) +@dy2static_unittest class TestWord2Vec(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): diff --git a/test/dygraph_to_static/test_yolov3.py b/test/dygraph_to_static/test_yolov3.py index 3f31b666c7f31d..12830ca7bce557 100644 --- a/test/dygraph_to_static/test_yolov3.py +++ b/test/dygraph_to_static/test_yolov3.py @@ -17,7 +17,10 @@ import unittest import numpy as np -from dygraph_to_static_util import test_and_compare_with_new_ir +from dygraph_to_static_util import ( + dy2static_unittest, + test_and_compare_with_new_ir, +) from yolov3 import YOLOv3, cfg import paddle @@ -165,6 +168,7 @@ def train(to_static): return np.array(ret) +@dy2static_unittest class TestYolov3(unittest.TestCase): @test_and_compare_with_new_ir(False) def test_dygraph_static_same_loss(self): diff --git a/test/ir/inference/CMakeLists.txt b/test/ir/inference/CMakeLists.txt index 5c6714e698444d..fa01fe99a9e3ed 100755 --- a/test/ir/inference/CMakeLists.txt +++ b/test/ir/inference/CMakeLists.txt @@ -181,10 +181,13 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_trt_inference_predictor PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_inference_fp16_io PROPERTIES TIMEOUT 300) set_tests_properties(test_trt_optimization_level PROPERTIES TIMEOUT 300) - set_tests_properties(test_trt_explicit_quantization_resnet PROPERTIES TIMEOUT - 300) - set_tests_properties(test_trt_explicit_quantization_mobilenet - PROPERTIES TIMEOUT 300) + set_tests_properties(test_trt_ops_fp32_mix_precision PROPERTIES TIMEOUT 300) + if(NOT WIN32) + set_tests_properties(test_trt_explicit_quantization_resnet + PROPERTIES TIMEOUT 300) + set_tests_properties(test_trt_explicit_quantization_mobilenet + PROPERTIES TIMEOUT 300) + endif() if(WITH_MKLDNN) set_tests_properties(test_save_optimized_model_pass PROPERTIES TIMEOUT 300) endif() diff --git a/test/ir/inference/test_trt_convert_share_data.py b/test/ir/inference/test_trt_convert_share_data.py new file mode 100644 index 00000000000000..168ef72b6e590b --- /dev/null +++ b/test/ir/inference/test_trt_convert_share_data.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertShareDataTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + compile_version = paddle_infer.get_trt_compile_version() + runtime_version = paddle_infer.get_trt_runtime_version() + if ( + compile_version[0] * 1000 + + compile_version[1] * 100 + + compile_version[2] * 10 + < 8400 + ): + return False + if ( + runtime_version[0] * 1000 + + runtime_version[1] * 100 + + runtime_version[2] * 10 + < 8400 + ): + return False + return True + + def sample_program_configs(self): + def generate_input(type): + if self.dims == 1: + return np.ones([1]).astype(type) + else: + return np.ones([1, 3, 64, 64]).astype(type) + + for dims in [1, 4]: + self.dims = dims + for dtype in [ + np.int32, + np.float32, + np.int64, + ]: + self.has_bool_dtype = dtype == np.bool_ + ops_config = [ + { + "op_type": "share_data", + "op_inputs": {"X": ["input_data"]}, + "op_outputs": {"Out": ["output_data0"]}, + "op_attrs": {}, + }, + { + "op_type": "share_data", + "op_inputs": {"X": ["output_data0"]}, + "op_outputs": {"Out": ["output_data1"]}, + "op_attrs": {}, + }, + ] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input, dtype) + ) + }, + outputs=["output_data1"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if self.dims == 1: + self.dynamic_shape.min_input_shape = {"input_data": [1]} + self.dynamic_shape.max_input_shape = {"input_data": [1]} + self.dynamic_shape.opt_input_shape = {"input_data": [1]} + else: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 3, 64, 64] + } + self.dynamic_shape.max_input_shape = { + "input_data": [1, 3, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 3, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and self.dims == 1: + return 0, 4 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + program_config.set_input_type(np.float32) + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + program_config.set_input_type(np.float16) + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-2 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + program_config.set_input_type(np.float32) + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + program_config.set_input_type(np.float16) + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/inference/test_trt_ops_fp32_mix_precision.py b/test/ir/inference/test_trt_ops_fp32_mix_precision.py new file mode 100644 index 00000000000000..c2fcb2255c95c6 --- /dev/null +++ b/test/ir/inference/test_trt_ops_fp32_mix_precision.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer +from paddle.inference import InternalUtils + + +class TestTrtFp32MixPrecision(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + + def generate_conv2d_input(): + return np.ones([1, 3, 64, 64]).astype(np.float32) + + def generate_conv2d_weight(): + return np.ones([9, 3, 3, 3]).astype(np.float32) + + def generate_elementwise_input(op_type): + # elementwise_floordiv is integer only + if op_type == "elementwise_mod": + return np.random.uniform( + low=0.1, high=1.0, size=[33, 10] + ).astype(np.float32) + else: + return np.random.random([33, 10]).astype(np.float32) + + def generate_elementwise_weight(op_type): + if op_type == "elementwise_mod": + return np.random.uniform( + low=0.1, high=1.0, size=[33, 1] + ).astype(np.float32) + else: + return np.random.randn(33, 1).astype(np.float32) + + attrs = [ + { + "data_fromat": 'NCHW', + "dilations": [1, 2], + "padding_algorithm": 'EXPLICIT', + "groups": 1, + "paddings": [0, 3], + "strides": [2, 2], + }, + {"axis": -1}, + { + "trans_x": False, + "trans_y": False, + }, + ] + for op_type in [ + "elementwise_add", + "elementwise_mul", + "elementwise_sub", + "elementwise_div", + "elementwise_pow", + "elementwise_min", + "elementwise_max", + "elementwise_mod", + ]: + ops_config = [ + { + "op_type": "conv2d", + "op_inputs": { + "Input": ["conv2d_input"], + "Filter": ["conv2d_weight"], + }, + "op_outputs": {"Output": ["conv_output_data"]}, + "op_attrs": attrs[0], + }, + { + "op_type": op_type, + "op_inputs": { + "X": ["elementwise_input"], + "Y": ["elementwise_weight"], + }, + "op_outputs": {"Out": ["elementwise_output_data"]}, + "op_attrs": attrs[1], + "outputs_dtype": {"output_data": np.float32}, + }, + { + "op_type": "matmul_v2", + "op_inputs": { + "X": ["conv_output_data"], + "Y": ["elementwise_output_data"], + }, + "op_outputs": {"Out": ["matmul_v2_output_data"]}, + "op_attrs": attrs[2], + }, + ] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "conv2d_weight": TensorConfig( + data_gen=partial(generate_conv2d_weight) + ), + "elementwise_weight": TensorConfig( + data_gen=partial(generate_elementwise_weight, op_type) + ), + }, + inputs={ + "conv2d_input": TensorConfig( + data_gen=partial(generate_conv2d_input) + ), + "elementwise_input": TensorConfig( + data_gen=partial(generate_elementwise_input, op_type) + ), + }, + outputs=["matmul_v2_output_data"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "conv2d_input": [1, 3, 64, 64], + "elementwise_input": [33, 10], + } + self.dynamic_shape.max_input_shape = { + "conv2d_input": [1, 3, 64, 64], + "elementwise_input": [33, 10], + } + self.dynamic_shape.opt_input_shape = { + "conv2d_input": [1, 3, 64, 64], + "elementwise_input": [33, 10], + } + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 3 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Half + program_config.set_input_type(np.float16) + config = self.create_inference_config() + InternalUtils.disable_tensorrt_half_ops( + config, + { + "conv_output_data", + "elementwise_output_data", + "matmul_v2_output_data", + }, + ) + yield config, generate_trt_nodes_num(attrs, True), (1e-3, 1e-3) + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index a2a17feb1275fa..032f451bda842e 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -35,7 +35,71 @@ def test_op(self): x = paddle.to_tensor([2, 3, 4], 'float64') y = paddle.cast(x, 'uint8') - _ = pir.translate_to_new_ir(main_program.desc) + _, mappings = pir.translate_to_new_ir_with_param_map(main_program.desc) + assert len(str(mappings)) > 0, "no mapping found" + + +class TestCondWithInplace(unittest.TestCase): + def test_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], dtype="float32") + if x > y: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + ) + + l = pir.translate_to_new_ir(legacy_program.main_program.desc) + assert l is not None + + def test_nested_op(self): + def cond_with_inplace(): + x = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + y = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + z = paddle.ones(shape=[2, 1, 2, 3], dtype="float32") + running_mean = paddle.to_tensor([0], dtype="float32") + running_variance = paddle.to_tensor([1], dtype="float32") + weight = paddle.to_tensor([2], dtype="float32") + bias = paddle.to_tensor([1], dtype="float32") + if x > y: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + else: + if y > z: + z = paddle.nn.functional.batch_norm( + z, running_mean, running_variance, weight, bias + ) + else: + y = paddle.nn.functional.batch_norm( + x, running_mean, running_variance, weight, bias + ) + + legacy_program = paddle.jit.to_static( + cond_with_inplace, + input_spec=[], + ) + + l = pir.translate_to_new_ir(legacy_program.main_program.desc) + assert l is not None class TestElementwiseOpTranscriber(unittest.TestCase): @@ -100,6 +164,27 @@ def test_elementwise_with_y_grad(self): atol=1e-6, ) + def test_add_inplace(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + exe = paddle.static.Executor(place) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + x = paddle.ones(shape=(100, 2, 3), dtype='float32') + y = paddle.ones(shape=(100, 2, 3), dtype='float32') + + helper = LayerHelper('elementwise_add') + helper.append_op( + type="elementwise_add", + inputs={"X": x, "Y": y}, + outputs={"Out": y}, + attrs={"axis": -1}, + ) + _ = pir.translate_to_new_ir(main_program.desc) + class TestEmbeddingOpTranscriber(unittest.TestCase): def test_op(self): @@ -408,6 +493,30 @@ def test_grad(self): self.assertTrue((ret[0][6:0:-4] == 0).all()) +class TestShareBufferOpTranscriber(unittest.TestCase): + def test_program(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + x = paddle.ones(shape=(100, 2, 3), dtype='float32') + y = paddle.ones(shape=(100, 2, 3), dtype='float32') + + helper = LayerHelper('share_buffer') + helper.append_op( + type="share_buffer", + inputs={"X": x}, + outputs={"Out": y, "XOut": x}, + ) + l = pir.translate_to_new_ir(main_program.desc) + assert ( + l.global_block().ops[2].name() == "pd_op.share_data" + ), "share_buffer should be translated to share_data" + + class TestCheckUnregisteredOp(unittest.TestCase): def test_program(self): main_program = paddle.static.Program() diff --git a/test/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py b/test/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py index 4c0b5d5689885b..ce96268f788b4b 100644 --- a/test/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py +++ b/test/ir/test_ir_embedding_eltwise_layernorm_fuse_pass.py @@ -121,14 +121,13 @@ def setUp(self): self.num_fused_ops = 2 def test_check_output(self): - use_gpu_set = [True] if not core.is_compiled_with_cuda(): return self.pass_attrs = { "embedding_eltwise_layernorm_fuse_pass": {"use_gpu": True} } place = base.CUDAPlace(0) - self.check_output_with_place(place, startup_on_cpu=True) + self.check_output_with_place(place) if __name__ == "__main__": diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index eae8a200212dfe..77ca5512d2b4fa 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -1967,7 +1967,7 @@ def check_output_with_place( only_check_prim=False, inplace_atol=None, check_cinn=False, - check_new_ir=False, + check_pir=False, ): core._set_prim_all_enabled(False) core.set_prim_eager_enabled(False) @@ -2538,7 +2538,7 @@ def _is_skip_name(self, name): dygraph_checker.check() dygraph_dygraph_outs = dygraph_checker.outputs - if check_new_ir: + if check_pir: if ( type(place) is paddle.base.libpaddle.CPUPlace or type(place) is paddle.base.libpaddle.CUDAPlace @@ -2657,7 +2657,7 @@ def check_output( inplace_atol=None, check_cinn=False, only_check_prim=False, - check_new_ir=False, + check_pir=False, ): self.__class__.op_type = self.op_type if self.is_mkldnn_op(): @@ -2683,7 +2683,7 @@ def check_output( only_check_prim=only_check_prim, inplace_atol=inplace_atol, check_cinn=check_cinn, - check_new_ir=check_new_ir, + check_pir=check_pir, ) if not res and only_check_prim: continue @@ -2700,7 +2700,7 @@ def check_output( self.check_compile_vs_runtime(fetch_list, outs) def check_output_customized( - self, checker, custom_place=None, check_new_ir=False + self, checker, custom_place=None, check_pir=False ): self.__class__.op_type = self.op_type places = self._get_places() @@ -2711,7 +2711,7 @@ def check_output_customized( outs = [np.array(out) for out in outs] outs.sort(key=len) checker(outs) - if check_new_ir: + if check_pir: with paddle.pir_utils.IrGuard(): outs_p = self._calc_new_ir_output(place) outs_p = [outs_p[out] for out in outs_p] @@ -2719,18 +2719,18 @@ def check_output_customized( checker(outs_p[0]) def check_output_with_place_customized( - self, checker, place, check_new_ir=False + self, checker, place, check_pir=False ): outs = self.calc_output(place) outs = [np.array(out) for out in outs] outs.sort(key=len) checker(outs) - if check_new_ir: + if check_pir: with paddle.pir_utils.IrGuard(): outs_p = self._calc_new_ir_output(place) - outs_p = [outs_p[out] for out in outs_p] + outs_p = [outs_p[out][0] for out in outs_p] outs_p.sort(key=len) - checker(outs_p[0]) + checker(outs_p) def _assert_is_close( self, @@ -2867,7 +2867,7 @@ def check_grad( only_check_prim=False, atol=1e-5, check_cinn=False, - check_new_ir=False, + check_pir=False, ): if hasattr(self, "use_custom_device") and self.use_custom_device: check_dygraph = False @@ -2891,7 +2891,7 @@ def check_grad( only_check_prim=only_check_prim, atol=atol, check_cinn=check_cinn, - check_new_ir=check_new_ir, + check_pir=check_pir, ) def check_grad_with_place( @@ -2912,7 +2912,7 @@ def check_grad_with_place( numeric_place=None, atol=1e-5, check_cinn=False, - check_new_ir=False, + check_pir=False, ): if hasattr(self, "use_custom_device") and self.use_custom_device: check_dygraph = False @@ -3126,7 +3126,7 @@ def check_grad_with_place( ) # get pir gradient - if check_new_ir: + if check_pir: if ( type(place) is paddle.base.libpaddle.CPUPlace or type(place) is paddle.base.libpaddle.CUDAPlace diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index f4f56d049524a3..2fbc85815ff58f 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -27,6 +27,7 @@ from paddle import base, static from paddle.base import Program, core, program_guard from paddle.base.layer_helper import LayerHelper +from paddle.pir_utils import test_with_pir_api @contextmanager @@ -127,10 +128,16 @@ def setUp(self): self.convert_input_output() def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def init_dtype(self): self.dtype = np.float32 @@ -174,12 +181,10 @@ def setUp(self): self.convert_input_output() def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad( - ['X'], 'Out', max_relative_error=0.006, check_new_ir=True - ) + self.check_grad(['X'], 'Out', max_relative_error=0.006, check_pir=True) def init_dtype(self): self.dtype = np.complex64 @@ -249,10 +254,10 @@ def setUp(self): self.convert_input_output() def test_check_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestExpm1_Complex64(TestExpm1): @@ -260,10 +265,10 @@ def init_dtype(self): self.dtype = np.complex64 def test_check_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestExpm1_Complex128(TestExpm1_Complex64): @@ -383,10 +388,19 @@ def init_dtype(self): def if_enable_cinn(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.01, check_prim=True) + self.check_grad( + ['X'], + 'Out', + max_relative_error=0.01, + check_prim=True, + check_pir=True, + ) class TestSigmoid_Complex64(TestSigmoid): @@ -394,7 +408,7 @@ def init_dtype(self): self.dtype = np.complex64 def test_check_output(self): - self.check_output(check_prim=False) + self.check_output(check_prim=False, check_pir=True) def test_check_grad(self): self.check_grad( @@ -402,6 +416,7 @@ def test_check_grad(self): 'Out', max_relative_error=0.006, check_prim=False, + check_pir=True, ) @@ -410,11 +425,7 @@ def init_dtype(self): self.dtype = np.complex128 def test_check_grad(self): - self.check_grad( - ['X'], - 'Out', - check_prim=False, - ) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) class TestSigmoid_ZeroDim(TestSigmoid): @@ -455,12 +466,13 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - # elementwise_pow doesn't support bfloat16, skip check_prim here. - self.check_output_with_place(place, check_prim=True) + self.check_output_with_place(place, check_prim=True, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', check_prim=True) + self.check_grad_with_place( + place, ['X'], 'Out', check_prim=True, check_pir=True + ) ''' @@ -501,14 +513,25 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + if self.dtype == np.complex64 or self.dtype == np.complex128: + self.check_output(check_pir=True) + else: + self.check_output( + check_prim=True, check_pir=True, check_prim_pir=True + ) def test_check_grad(self): # TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) else: - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestSilu_ZeroDim(TestSilu): @@ -694,7 +717,7 @@ def setUp(self): self.convert_input_output() def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -706,14 +729,14 @@ def test_check_grad(self): 'Out', check_prim=False, check_prim_pir=False, - check_new_ir=True, + check_pir=True, ) else: self.check_grad( ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -1439,7 +1462,7 @@ def test_errors(self): class TestSqrt(TestActivation, TestParameter): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt self.public_python_api = paddle.sqrt @@ -1461,16 +1484,22 @@ def if_enable_cinn(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) class TestSqrtPrimFp32(TestActivation): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt self.public_python_api = paddle.sqrt self.init_dtype() @@ -1486,10 +1515,16 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True, check_prim_pir=True) def init_dtype(self): self.dtype = np.float32 @@ -1510,7 +1545,7 @@ def init_shape(self): class TestSqrtBF16(OpTest): def setUp(self): self.op_type = "sqrt" - self.prim_op_type = "prim" + self.prim_op_type = "comp" self.python_api = paddle.sqrt self.public_python_api = paddle.sqrt self.init_dtype() @@ -1537,12 +1572,17 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True, check_prim_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -1571,12 +1611,20 @@ def test_check_grad(self): if self.dtype == np.float16: return self.check_grad( - ['X'], 'Out', check_dygraph=True, check_prim=True, check_new_ir=True + ['X'], + 'Out', + check_dygraph=True, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def test_check_output(self): self.check_output( - check_dygraph=True, check_prim=True, check_new_ir=True + check_dygraph=True, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -1603,12 +1651,20 @@ def test_check_grad(self): if self.dtype == np.float16: return self.check_grad( - ['X'], 'Out', check_dygraph=True, check_prim=True, check_new_ir=True + ['X'], + 'Out', + check_dygraph=True, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def test_check_output(self): self.check_output( - check_dygraph=True, check_prim=True, check_new_ir=True + check_dygraph=True, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def init_dtype(self): @@ -1645,9 +1701,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output( - check_prim=True, check_new_ir=True, check_prim_pir=True - ) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -1657,7 +1711,7 @@ def test_check_grad(self): 'Out', max_relative_error=0.0005, check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -1700,12 +1754,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestAbs_ZeroDim(TestAbs): @@ -1732,7 +1786,7 @@ def init_shape(self): self.shape = [10, 12] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) # The same reason with TestFloor def test_check_grad(self): @@ -1768,6 +1822,9 @@ def init_shape(self): def if_enable_cinn(self): pass + def test_check_output(self): + self.check_output(check_pir=True) + # the gradient on floor, ceil, round is undefined. # we return zero as gradient, but the numpy return nan # The same reason with TestFloor @@ -1786,6 +1843,7 @@ def test_check_grad_for_prim(self): 'Out', check_prim=True, only_check_prim=True, + check_pir=True, ) @@ -1819,6 +1877,9 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return @@ -1826,10 +1887,14 @@ def test_check_grad(self): if self.dtype == np.complex64 or self.dtype == np.complex128: # Complex64 [GPU]: AssertionError: 0.0057843705 not less than or equal to 0.005 self.check_grad( - ['X'], 'Out', check_prim=False, max_relative_error=0.006 + ['X'], + 'Out', + check_prim=False, + max_relative_error=0.006, + check_pir=True, ) else: - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -2016,14 +2081,22 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + @test_with_pir_api + def test_out_name(self): + # inherit from `TestParameter` + super().test_out_name() + + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): if self.dtype == np.float16: return # TODO(ScottWong98): set `check_prim=False` when `fill_any_like` supports `complex` dtype if self.dtype == np.complex64 or self.dtype == np.complex128: - self.check_grad(['X'], 'Out', check_prim=False) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) else: - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -2246,6 +2319,9 @@ def setUp(self): def init_shape(self): self.shape = [10, 12] + def test_check_output(self): + self.check_output(check_pir=True) + def test_check_grad(self): pass @@ -2278,10 +2354,10 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output(check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -2521,7 +2597,7 @@ def setUp(self): def test_check_output(self): self.check_output( check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=False, ) @@ -2532,7 +2608,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -2567,9 +2643,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output( - check_prim=True, check_new_ir=True, check_prim_pir=False - ) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=False) def test_check_grad(self): if self.dtype == np.float16: @@ -2578,7 +2652,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -2605,6 +2679,7 @@ def setUp(self): self.rev_comp_rtol = 1e-8 self.rev_comp_atol = 1e-8 + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -3261,12 +3336,18 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class Test_Log_Op_Fp16(unittest.TestCase): @@ -3592,10 +3673,10 @@ def setUp(self): def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007, check_pir=True) def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestSquare_ZeroDim(TestSquare): @@ -3627,11 +3708,13 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.5) + self.check_grad_with_place( + place, ['X'], 'Out', numeric_grad_delta=0.5, check_pir=True + ) class TestPow(TestActivation): @@ -3657,9 +3740,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -3669,7 +3750,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -3877,6 +3958,11 @@ def setUp(self): np.random.seed(1024) x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + if self.dtype == np.complex64 or self.dtype == np.complex128: + x = ( + np.random.uniform(-1, 1, self.shape) + + 1j * np.random.uniform(-1, 1, self.shape) + ).astype(self.dtype) out = ref_softplus(x, beta, threshold) self.inputs = {'X': x} self.attrs = {'beta': beta, "threshold": threshold} @@ -3891,6 +3977,19 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') +class TestSoftplus_Complex64(TestSoftplus): + def init_dtype(self): + self.dtype = np.complex64 + + def test_check_grad(self): + self.check_grad(['X'], 'Out', max_relative_error=0.06) + + +class TestSoftplus_Complex128(TestSoftplus): + def init_dtype(self): + self.dtype = np.complex128 + + class TestSoftplus_ZeroDim(TestSoftplus): def init_shape(self): self.shape = [] @@ -4505,7 +4604,7 @@ def create_test_act_fp16_class( check_prim=False, check_prim_pir=False, enable_cinn=False, - check_new_ir=False, + check_pir=False, grad_atol=1e-2, **kwargs ): @@ -4534,7 +4633,7 @@ def test_check_output(self): check_dygraph=check_dygraph, check_prim=check_prim, check_prim_pir=check_prim_pir, - check_new_ir=check_new_ir, + check_pir=check_pir, ) def test_check_grad(self): @@ -4549,7 +4648,7 @@ def test_check_grad(self): check_prim=check_prim, check_prim_pir=check_prim_pir, max_relative_error=grad_atol, - check_new_ir=check_new_ir, + check_pir=check_pir, ) cls_name = "{}_{}".format(parent.__name__, "FP16OP") @@ -4558,10 +4657,16 @@ def test_check_grad(self): create_test_act_fp16_class(TestActivation) -create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestExpFp32_Prim, check_prim=True, enable_cinn=True, check_prim_pir=True +) create_test_act_fp16_class(TestExpm1) -create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True) -create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestSigmoid, check_prim=True, enable_cinn=True, check_pir=True +) +create_test_act_fp16_class( + TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True +) create_test_act_fp16_class(TestLogSigmoid) create_test_act_fp16_class( TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True @@ -4570,38 +4675,50 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardShrink) create_test_act_fp16_class(TestSoftshrink) create_test_act_fp16_class( - TestSqrt, check_prim=True, enable_cinn=True, check_new_ir=True + TestSqrt, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class( - TestSqrtComp, check_prim=True, enable_cinn=True, check_new_ir=True + TestSqrtComp, + check_prim=True, + enable_cinn=True, + check_pir=True, + check_prim_pir=True, ) create_test_act_fp16_class( - TestAbs, check_prim=True, enable_cinn=True, check_new_ir=True + TestAbs, check_prim=True, enable_cinn=True, check_pir=True ) -create_test_act_fp16_class(TestCeil, grad_check=False, check_new_ir=True) +create_test_act_fp16_class(TestCeil, grad_check=False, check_pir=True) create_test_act_fp16_class( - TestFloor, check_prim=True, grad_check=False, enable_cinn=True + TestFloor, + check_prim=True, + grad_check=False, + enable_cinn=True, + check_pir=True, ) -create_test_act_fp16_class(TestCos) +create_test_act_fp16_class(TestCos, check_pir=True) create_test_act_fp16_class(TestTan) create_test_act_fp16_class(TestCosh) create_test_act_fp16_class(TestAcos) -create_test_act_fp16_class(TestSin) +create_test_act_fp16_class(TestSin, check_pir=True) create_test_act_fp16_class(TestSinh) create_test_act_fp16_class(TestAsin) create_test_act_fp16_class(TestAtan) create_test_act_fp16_class(TestAcosh) create_test_act_fp16_class(TestAsinh) create_test_act_fp16_class(TestAtanh) -create_test_act_fp16_class(TestRound, grad_check=False) +create_test_act_fp16_class(TestRound, grad_check=False, check_pir=True) create_test_act_fp16_class( - TestRelu, check_prim=True, enable_cinn=True, check_new_ir=True + TestRelu, check_prim=True, enable_cinn=True, check_pir=True ) create_test_act_fp16_class( TestGelu, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, enable_cinn=True, rev_comp_rtol=1e-3, rev_comp_atol=1e-3, @@ -4614,14 +4731,14 @@ def test_check_grad(self): create_test_act_fp16_class(TestELU) create_test_act_fp16_class(TestCELU) create_test_act_fp16_class(TestReciprocal) -create_test_act_fp16_class(TestLog, check_prim=True, check_new_ir=True) +create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True) if core.is_compiled_with_rocm(): create_test_act_fp16_class(TestLog2) else: create_test_act_fp16_class(TestLog2) create_test_act_fp16_class(TestLog10) create_test_act_fp16_class(TestLog1p) -create_test_act_fp16_class(TestSquare) +create_test_act_fp16_class(TestSquare, check_pir=True) create_test_act_fp16_class(TestPow, check_prim=True, check_prim_pir=True) create_test_act_fp16_class(TestPow_API) create_test_act_fp16_class(TestSTanh) @@ -4647,7 +4764,7 @@ def test_check_grad(self): TestRsqrt, check_prim=True, enable_cinn=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -4659,7 +4776,7 @@ def create_test_act_bf16_class( check_dygraph=True, check_prim=False, enable_cinn=False, - check_new_ir=False, + check_pir=False, grad_atol=1e-2, **kwargs ): @@ -4691,7 +4808,7 @@ def test_check_output(self): place, atol=atol, check_prim=check_prim, - check_new_ir=check_new_ir, + check_pir=check_pir, ) def test_check_grad(self): @@ -4703,7 +4820,7 @@ def test_check_grad(self): 'Out', max_relative_error=grad_atol, check_prim=check_prim, - check_new_ir=check_new_ir, + check_pir=check_pir, ) cls_name = "{}_{}".format(parent.__name__, "BF16OP") @@ -4712,37 +4829,45 @@ def test_check_grad(self): create_test_act_bf16_class(TestActivation) -create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True) +create_test_act_bf16_class( + TestExpFp32_Prim, check_prim=True, check_prim_pir=True +) create_test_act_bf16_class(TestExpm1) -create_test_act_bf16_class(TestSigmoid, check_prim=True) -create_test_act_bf16_class(TestSilu, check_prim=True) +create_test_act_bf16_class(TestSigmoid, check_prim=True, check_pir=True) +create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestLogSigmoid) -create_test_act_bf16_class(TestTanh, check_prim=True) +create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True) create_test_act_bf16_class(TestTanhshrink) create_test_act_bf16_class(TestHardShrink) create_test_act_bf16_class(TestSoftshrink) -create_test_act_bf16_class(TestSqrt, check_prim=True, check_new_ir=True) -create_test_act_bf16_class(TestSqrtComp, check_prim=True, check_new_ir=True) -create_test_act_bf16_class(TestAbs, check_prim=True, check_new_ir=True) -create_test_act_bf16_class(TestCeil, grad_check=False, check_new_ir=True) -create_test_act_bf16_class(TestFloor, grad_check=False, check_prim=True) -create_test_act_bf16_class(TestCos) +create_test_act_bf16_class( + TestSqrt, check_prim=True, check_pir=True, check_prim_pir=True +) +create_test_act_bf16_class( + TestSqrtComp, check_prim=True, check_pir=True, check_prim_pir=True +) +create_test_act_bf16_class(TestAbs, check_prim=True, check_pir=True) +create_test_act_bf16_class(TestCeil, grad_check=False, check_pir=True) +create_test_act_bf16_class( + TestFloor, grad_check=False, check_prim=True, check_pir=True +) +create_test_act_bf16_class(TestCos, check_pir=True) create_test_act_bf16_class(TestTan) create_test_act_bf16_class(TestCosh) create_test_act_bf16_class(TestAcos) -create_test_act_bf16_class(TestSin) +create_test_act_bf16_class(TestSin, check_pir=True) create_test_act_bf16_class(TestSinh) create_test_act_bf16_class(TestAsin) create_test_act_bf16_class(TestAtan) create_test_act_bf16_class(TestAcosh) create_test_act_bf16_class(TestAsinh) create_test_act_bf16_class(TestAtanh) -create_test_act_bf16_class(TestRound, grad_check=False) -create_test_act_bf16_class(TestRelu, check_prim=True, check_new_ir=True) +create_test_act_bf16_class(TestRound, grad_check=False, check_pir=True) +create_test_act_bf16_class(TestRelu, check_prim=True, check_pir=True) create_test_act_bf16_class( TestGelu, check_prim=True, - check_new_ir=True, + check_pir=True, rev_comp_rtol=1e-2, rev_comp_atol=1e-2, cinn_rtol=1e-2, @@ -4754,14 +4879,14 @@ def test_check_grad(self): create_test_act_bf16_class(TestELU) create_test_act_bf16_class(TestCELU) create_test_act_bf16_class(TestReciprocal) -create_test_act_bf16_class(TestLog, check_prim=True, check_new_ir=True) +create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True) if core.is_compiled_with_rocm(): create_test_act_bf16_class(TestLog2) else: create_test_act_bf16_class(TestLog2) create_test_act_bf16_class(TestLog10) create_test_act_bf16_class(TestLog1p) -create_test_act_bf16_class(TestSquare) +create_test_act_bf16_class(TestSquare, check_pir=True) create_test_act_bf16_class(TestPow, check_prim=True) create_test_act_bf16_class(TestPow_API) create_test_act_bf16_class(TestSTanh) @@ -4778,7 +4903,7 @@ def test_check_grad(self): create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True) create_test_act_bf16_class( - TestRsqrt, check_prim=True, check_new_ir=True, check_prim_pir=True + TestRsqrt, check_prim=True, check_pir=True, check_prim_pir=True ) if __name__ == "__main__": diff --git a/test/legacy_test/test_allclose_op.py b/test/legacy_test/test_allclose_op.py index 754a5c81509794..54e78867e7443f 100644 --- a/test/legacy_test/test_allclose_op.py +++ b/test/legacy_test/test_allclose_op.py @@ -53,7 +53,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllcloseOpException(TestAllcloseOp): @@ -61,28 +61,28 @@ def test_check_output(self): def test_rtol_num(): self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64") self.inputs['Atol'] = np.array([1e-08]).astype("float64") - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) self.assertRaises(ValueError, test_rtol_num) def test_rtol_type(): self.inputs['Rtol'] = np.array([5]).astype("int32") self.inputs['Atol'] = np.array([1e-08]).astype("float64") - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) self.assertRaises(ValueError, test_rtol_type) def test_atol_num(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64") - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) self.assertRaises(ValueError, test_atol_num) def test_atol_type(): self.inputs['Rtol'] = np.array([1e-05]).astype("float64") self.inputs['Atol'] = np.array([8]).astype("int32") - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) self.assertRaises(ValueError, test_atol_type) @@ -200,7 +200,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) class TestAllcloseOpFloat32(TestAllcloseOp): diff --git a/test/legacy_test/test_arange.py b/test/legacy_test/test_arange.py index d22ec561e00012..e71402518696ba 100644 --- a/test/legacy_test/test_arange.py +++ b/test/legacy_test/test_arange.py @@ -48,7 +48,7 @@ def init_config(self): self.case = (0, 1, 0.2) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestFloatArangeOp(TestArangeOp): @@ -65,7 +65,7 @@ def init_config(self): self.case = (0, 5, 1) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) @unittest.skipIf( @@ -99,7 +99,7 @@ def init_config(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) class TestInt32ArangeOp(TestArangeOp): diff --git a/test/legacy_test/test_arg_min_max_op.py b/test/legacy_test/test_arg_min_max_op.py index 09425be02fc53e..ede4a54a244ed5 100644 --- a/test/legacy_test/test_arg_min_max_op.py +++ b/test/legacy_test/test_arg_min_max_op.py @@ -42,7 +42,7 @@ def setUp(self): self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} def test_check_output(self): - self.check_output(check_cinn=True) + self.check_output(check_cinn=True, check_pir=True) class TestCase0(BaseTestCase): @@ -122,7 +122,7 @@ def setUp(self): self.outputs = {'Out': np.argmax(x, axis=self.axis)} def test_check_output(self): - self.check_output_with_place(paddle.CUDAPlace(0)) + self.check_output_with_place(paddle.CUDAPlace(0), check_pir=True) class TestArgMaxBF16OP(TestArgMinBF16OP): diff --git a/test/legacy_test/test_assign_op.py b/test/legacy_test/test_assign_op.py index 4a9ff9308f7b82..270fe45ffe7429 100644 --- a/test/legacy_test/test_assign_op.py +++ b/test/legacy_test/test_assign_op.py @@ -24,6 +24,7 @@ from paddle import base from paddle.base import Program, core, program_guard from paddle.base.backward import append_backward +from paddle.pir_utils import test_with_pir_api class TestAssignOp(op_test.OpTest): @@ -42,12 +43,12 @@ def init_input_configs(self): def test_forward(self): paddle.enable_static() - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) paddle.disable_static() def test_backward(self): paddle.enable_static() - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) paddle.disable_static() @@ -71,12 +72,12 @@ def setUp(self): def test_forward(self): paddle.enable_static() - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) paddle.disable_static() def test_backward(self): paddle.enable_static() - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) paddle.disable_static() @@ -97,12 +98,12 @@ def setUp(self): def test_forward(self): paddle.enable_static() - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) paddle.disable_static() def test_backward(self): paddle.enable_static() - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) paddle.disable_static() @@ -275,6 +276,7 @@ def test_assign_bfp16(self): class TestAssignOpErrorApi(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() with program_guard(Program(), Program()): @@ -288,6 +290,7 @@ def test_errors(self): self.assertRaises(TypeError, paddle.assign, x2) paddle.disable_static() + @test_with_pir_api def test_type_error(self): paddle.enable_static() with program_guard(Program(), Program()): diff --git a/test/legacy_test/test_assign_value_op.py b/test/legacy_test/test_assign_value_op.py index e6828351628702..3bdc97f4247f3a 100644 --- a/test/legacy_test/test_assign_value_op.py +++ b/test/legacy_test/test_assign_value_op.py @@ -54,7 +54,7 @@ def init_data(self): self.attrs["fp32_values"] = [float(v) for v in self.value.flat] def test_forward(self): - self.check_output(check_cinn=True, check_new_ir=False) + self.check_output(check_cinn=True, check_pir=True) class TestAssignValueOp2(TestAssignValueOp): @@ -105,6 +105,18 @@ def test_assign(self): np.testing.assert_array_equal(fetched_x, self.value) self.assertEqual(fetched_x.dtype, self.value.dtype) + def test_pir_assign(self): + with paddle.pir_utils.IrGuard(): + main_program = paddle.pir.Program() + with paddle.pir.core.program_guard(main_program): + x = paddle.zeros(shape=[1], dtype=self.dtype) + paddle.assign(self.value, output=x) + + exe = base.Executor(self.place) + [fetched_x] = exe.run(main_program, feed={}, fetch_list=[x]) + np.testing.assert_array_equal(fetched_x, self.value) + self.assertEqual(fetched_x.dtype, self.value.dtype) + class TestAssignApi2(TestAssignApi): def init_dtype(self): diff --git a/test/legacy_test/test_bitwise_op.py b/test/legacy_test/test_bitwise_op.py index a5040b434b260a..21a7abe812ad7a 100644 --- a/test/legacy_test/test_bitwise_op.py +++ b/test/legacy_test/test_bitwise_op.py @@ -43,7 +43,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): pass diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py index 79a8926162fa40..d9999aad5f9ddc 100644 --- a/test/legacy_test/test_cast_op.py +++ b/test/legacy_test/test_cast_op.py @@ -52,7 +52,7 @@ def init_shapes(self): self.input_shape = [10, 10] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_grad(self): self.check_grad( @@ -60,7 +60,7 @@ def test_grad(self): ['Out'], check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -84,7 +84,7 @@ def setUp(self): self.public_python_api = cast_wrapper def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_grad(self): self.check_grad( @@ -92,7 +92,7 @@ def test_grad(self): ['Out'], check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -111,7 +111,7 @@ def setUp(self): self.public_python_api = cast_wrapper def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_grad(self): self.check_grad( @@ -119,7 +119,7 @@ def test_grad(self): ['Out'], check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -146,7 +146,7 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_grad(self): self.check_grad( @@ -154,7 +154,7 @@ def test_grad(self): ['Out'], check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -181,7 +181,7 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_grad(self): self.check_grad( @@ -189,7 +189,7 @@ def test_grad(self): ['Out'], check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) diff --git a/test/legacy_test/test_clip_op.py b/test/legacy_test/test_clip_op.py index 2de2d94047363d..6ac2b0a17e7ca4 100644 --- a/test/legacy_test/test_clip_op.py +++ b/test/legacy_test/test_clip_op.py @@ -55,12 +55,12 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) paddle.disable_static() def test_check_grad_normal(self): paddle.enable_static() - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) paddle.disable_static() def initTestCase(self): @@ -194,14 +194,14 @@ def test_check_output(self): if paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) paddle.enable_static() - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) paddle.disable_static() def test_check_grad_normal(self): if paddle.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) paddle.enable_static() - self.check_grad_with_place(place, ['X'], 'Out', check_new_ir=True) + self.check_grad_with_place(place, ['X'], 'Out', check_pir=True) paddle.disable_static() def initTestCase(self): diff --git a/test/legacy_test/test_collective_api_base.py b/test/legacy_test/test_collective_api_base.py index 08de4a1be9a322..a431d77cdfe713 100644 --- a/test/legacy_test/test_collective_api_base.py +++ b/test/legacy_test/test_collective_api_base.py @@ -359,6 +359,7 @@ def check_with_place( "PATH_ID": path_id, "DTYPE": dtype, "REDUCE_TYPE": str(reduce_type), + "FLAGS_dynamic_static_unified_comm": "0", } required_envs.update(additional_envs) required_envs.update(need_envs) @@ -608,16 +609,23 @@ def convertbf16(origin): send_ptr2 = send_ptr2 + global_expert_count2[idx] result1 = [] result2 = [] + + def is_empyt_list(x): + if isinstance(x, list) and len(x) == 0: + return True + return False + for i in range(tot_expert): for arr in output1[i]: - if arr == []: + if is_empyt_list(arr): continue result1.append(arr) for i in range(tot_expert): for arr in output2[i]: - if arr == []: + if is_empyt_list(arr): continue result2.append(arr) + if result1 == []: output1 = np.array([]) else: diff --git a/test/legacy_test/test_collective_base.py b/test/legacy_test/test_collective_base.py index 9d3a602b8d051a..544cee3ac0e7ec 100644 --- a/test/legacy_test/test_collective_base.py +++ b/test/legacy_test/test_collective_base.py @@ -266,7 +266,7 @@ def check_with_place( "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), "GLOG_v": "3", "NCCL_P2P_DISABLE": "1", - "Flags_dynamic_static_unified_comm": "False", + "FLAGS_dynamic_static_unified_comm": "0", "DTYPE": "float32", } required_envs.update(need_envs) diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 2bae19d180e2c6..91dce088ef88ef 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -20,10 +20,11 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api -def create_test_class(op_type, typename, callback, check_new_ir=False): +def create_test_class(op_type, typename, callback, check_pir=False): class Cls(op_test.OpTest): def setUp(self): a = numpy.random.random(size=(10, 7)).astype(typename) @@ -35,11 +36,13 @@ def setUp(self): self.op_type = op_type def test_output(self): - self.check_output(check_cinn=True, check_new_ir=check_new_ir) + self.check_output(check_cinn=True, check_pir=check_pir) def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32') y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32') a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16') @@ -58,14 +61,14 @@ def test_errors(self): if _type_name == 'float16' and (not core.is_compiled_with_cuda()): continue - create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) - create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) - create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('less_than', _type_name, lambda _a, _b: _a < _b, True) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b, True) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b, True) create_test_class( 'greater_equal', _type_name, lambda _a, _b: _a >= _b, True ) create_test_class('equal', _type_name, lambda _a, _b: _a == _b, True) - create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b, True) def create_paddle_case(op_type, callback): @@ -79,9 +82,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.place = paddle.CUDAPlace(0) + @test_with_pir_api def test_api(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[4], dtype='int64') y = paddle.static.data(name='y', shape=[4], dtype='int64') op = eval("paddle.%s" % (self.op_type)) @@ -93,10 +99,13 @@ def test_api(self): ) self.assertEqual((res == self.real_result).all(), True) + @test_with_pir_api def test_api_float(self): if self.op_type == "equal": paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[4], dtype='int64') y = paddle.static.data(name='y', shape=[], dtype='int64') op = eval("paddle.%s" % (self.op_type)) @@ -290,9 +299,12 @@ def test_dynamic_api_bool(self): self.assertEqual((out.numpy() == self.real_result).all(), True) paddle.enable_static() + @test_with_pir_api def test_broadcast_api_1(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data( name='x', shape=[1, 2, 1, 3], dtype='int32' ) @@ -308,9 +320,12 @@ def test_broadcast_api_1(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_broadcast_api_2(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32') y = paddle.static.data( name='y', shape=[1, 2, 1, 3], dtype='int32' @@ -326,9 +341,12 @@ def test_broadcast_api_2(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_broadcast_api_3(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[5], dtype='int32') y = paddle.static.data(name='y', shape=[3, 1], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -342,9 +360,12 @@ def test_broadcast_api_3(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_1(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[], dtype='int32') y = paddle.randint(-3, 3, shape=[], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -358,9 +379,12 @@ def test_zero_dim_api_1(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_2(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') y = paddle.randint(-3, 3, shape=[], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -374,9 +398,12 @@ def test_zero_dim_api_2(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_zero_dim_api_3(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.randint(-3, 3, shape=[], dtype='int32') y = paddle.randint(-3, 3, shape=[2, 3, 4], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -390,9 +417,12 @@ def test_zero_dim_api_3(self): real_result = callback(x_np, y_np) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_bool_api_4(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') y = paddle.static.data(name='y', shape=[3, 1], dtype='bool') op = eval("paddle.%s" % (self.op_type)) @@ -406,9 +436,12 @@ def test_bool_api_4(self): ) self.assertEqual((res == real_result).all(), True) + @test_with_pir_api def test_bool_broadcast_api_4(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') y = paddle.static.data(name='y', shape=[1], dtype='bool') op = eval("paddle.%s" % (self.op_type)) @@ -424,7 +457,9 @@ def test_bool_broadcast_api_4(self): def test_attr_name(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x = paddle.static.data(name='x', shape=[-1, 4], dtype='int32') y = paddle.static.data(name='y', shape=[-1, 4], dtype='int32') op = eval("paddle.%s" % (self.op_type)) @@ -445,7 +480,7 @@ def test_attr_name(self): # add bf16 tests -def create_bf16_case(op_type, callback, check_new_ir=False): +def create_bf16_case(op_type, callback, check_pir=False): class TestCompareOpBF16Op(op_test.OpTest): def setUp(self): self.op_type = op_type @@ -462,25 +497,27 @@ def setUp(self): self.outputs = {'Out': real_result} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=check_new_ir) + self.check_output(check_cinn=True, check_pir=check_pir) cls_name = f"BF16TestCase_{op_type}" TestCompareOpBF16Op.__name__ = cls_name globals()[cls_name] = TestCompareOpBF16Op -create_bf16_case('less_than', lambda _a, _b: _a < _b) -create_bf16_case('less_equal', lambda _a, _b: _a <= _b) -create_bf16_case('greater_than', lambda _a, _b: _a > _b) +create_bf16_case('less_than', lambda _a, _b: _a < _b, True) +create_bf16_case('less_equal', lambda _a, _b: _a <= _b, True) +create_bf16_case('greater_than', lambda _a, _b: _a > _b, True) create_bf16_case('greater_equal', lambda _a, _b: _a >= _b, True) create_bf16_case('equal', lambda _a, _b: _a == _b, True) -create_bf16_case('not_equal', lambda _a, _b: _a != _b) +create_bf16_case('not_equal', lambda _a, _b: _a != _b, True) class TestCompareOpError(unittest.TestCase): def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): # The input x and y of compare_op must be Variable. x = paddle.static.data(name='x', shape=[-1, 1], dtype="float32") y = base.create_lod_tensor( @@ -490,9 +527,12 @@ def test_errors(self): class API_TestElementwise_Equal(unittest.TestCase): + @test_with_pir_api def test_api(self): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.assign(np.array([3, 3], dtype="int32")) limit = paddle.assign(np.array([3, 2], dtype="int32")) out = paddle.equal(x=label, y=limit) @@ -501,7 +541,9 @@ def test_api(self): (res,) = exe.run(fetch_list=[out]) self.assertEqual((res == np.array([True, False])).all(), True) - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.assign(np.array([3, 3], dtype="int32")) limit = paddle.assign(np.array([3, 3], dtype="int32")) out = paddle.equal(x=label, y=limit) @@ -510,9 +552,12 @@ def test_api(self): (res,) = exe.run(fetch_list=[out]) self.assertEqual((res == np.array([True, True])).all(), True) + @test_with_pir_api def test_api_fp16(self): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): label = paddle.to_tensor([3, 3], dtype="float16") limit = paddle.to_tensor([3, 2], dtype="float16") out = paddle.equal(x=label, y=limit) @@ -524,6 +569,7 @@ def test_api_fp16(self): class API_TestElementwise_Greater_Than(unittest.TestCase): + @test_with_pir_api def test_api_fp16(self): paddle.enable_static() with paddle.static.program_guard( @@ -540,17 +586,21 @@ def test_api_fp16(self): class TestCompareOpPlace(unittest.TestCase): + @test_with_pir_api def test_place_1(self): paddle.enable_static() place = paddle.CPUPlace() if core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) - label = paddle.assign(np.array([3, 3], dtype="int32")) - limit = paddle.assign(np.array([3, 2], dtype="int32")) - out = paddle.less_than(label, limit) - exe = base.Executor(place) - (res,) = exe.run(fetch_list=[out]) - self.assertEqual((res == np.array([False, False])).all(), True) + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + label = paddle.assign(np.array([3, 3], dtype="int32")) + limit = paddle.assign(np.array([3, 2], dtype="int32")) + out = paddle.less_than(label, limit) + exe = base.Executor(place) + (res,) = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([False, False])).all(), True) def test_place_2(self): place = paddle.CPUPlace() diff --git a/test/legacy_test/test_concat_op.py b/test/legacy_test/test_concat_op.py index 153e1cc06d3085..efa87c36095706 100644 --- a/test/legacy_test/test_concat_op.py +++ b/test/legacy_test/test_concat_op.py @@ -53,9 +53,9 @@ def get_dtype(self): def test_check_output(self): if self.dtype == np.uint16: place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) else: - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): if self.dtype == np.uint16: @@ -65,7 +65,7 @@ def test_check_grad(self): ['x0'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad_with_place( @@ -73,7 +73,7 @@ def test_check_grad(self): ['x1'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad_with_place( @@ -81,7 +81,7 @@ def test_check_grad(self): ['x2'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) else: @@ -89,21 +89,21 @@ def test_check_grad(self): ['x0'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad( ['x1'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad( ['x2'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -199,12 +199,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=False) + self.check_output(check_pir=False) def test_check_grad(self): - self.check_grad(['x0'], 'Out', check_new_ir=False) - self.check_grad(['x1'], 'Out', check_new_ir=False) - self.check_grad(['x2'], 'Out', check_new_ir=False) + self.check_grad(['x0'], 'Out', check_pir=False) + self.check_grad(['x1'], 'Out', check_pir=False) + self.check_grad(['x2'], 'Out', check_pir=False) def init_test_data(self): self.x0 = np.random.random([100]).astype(self.dtype) @@ -243,28 +243,28 @@ def get_dtype(self): return "float64" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['x0'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad( ['x1'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) self.check_grad( ['x2'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -319,19 +319,13 @@ def test_check_grad(self): return if self.dtype == np.uint16: place = core.CUDAPlace(0) - self.check_grad_with_place( - place, ['x0'], 'Out', check_new_ir=True - ) - self.check_grad_with_place( - place, ['x1'], 'Out', check_new_ir=True - ) - self.check_grad_with_place( - place, ['x2'], 'Out', check_new_ir=True - ) + self.check_grad_with_place(place, ['x0'], 'Out', check_pir=True) + self.check_grad_with_place(place, ['x1'], 'Out', check_pir=True) + self.check_grad_with_place(place, ['x2'], 'Out', check_pir=True) else: - self.check_grad(['x0'], 'Out', check_new_ir=True) - self.check_grad(['x1'], 'Out', check_new_ir=True) - self.check_grad(['x2'], 'Out', check_new_ir=True) + self.check_grad(['x0'], 'Out', check_pir=True) + self.check_grad(['x1'], 'Out', check_pir=True) + self.check_grad(['x2'], 'Out', check_pir=True) cls_name = "{}_{}".format(parent.__name__, "AxisTensor") TestConcatAxisTensor.__name__ = cls_name @@ -388,7 +382,7 @@ def test_check_grad(self): place, ['x0'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -396,7 +390,7 @@ def test_check_grad(self): place, ['x1'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -404,7 +398,7 @@ def test_check_grad(self): place, ['x2'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -412,21 +406,21 @@ def test_check_grad(self): self.check_grad( ['x0'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) self.check_grad( ['x1'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) self.check_grad( ['x2'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -493,7 +487,7 @@ def test_check_grad(self): place, ['x0'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -501,7 +495,7 @@ def test_check_grad(self): place, ['x1'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -509,7 +503,7 @@ def test_check_grad(self): place, ['x2'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -517,21 +511,21 @@ def test_check_grad(self): self.check_grad( ['x0'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) self.check_grad( ['x1'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) self.check_grad( ['x2'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) diff --git a/test/legacy_test/test_conv2d_layer.py b/test/legacy_test/test_conv2d_layer.py index 4290a7352afed9..a347472bd2a873 100644 --- a/test/legacy_test/test_conv2d_layer.py +++ b/test/legacy_test/test_conv2d_layer.py @@ -218,8 +218,53 @@ def paddle_nn_layer(self): t1 = x_var.gradient() return y_np, t1 + def run_Conv2D_static(self, place): + paddle.seed(2023) + main = base.Program() + start = base.Program() + with base.unique_name.guard(): + with base.program_guard(main, start): + x_var = paddle.static.data( + "input", self.input.shape, dtype=self.dtype + ) + conv = nn.Conv2D( + self.num_channels, + self.num_filters, + self.filter_size, + padding=self.padding, + padding_mode=self.padding_mode, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + data_format=self.data_format, + ) + y_var = conv(x_var) + feed_dict = {"input": self.input} + exe = base.Executor(place) + exe.run(start) + (y_np,) = exe.run(main, feed=feed_dict, fetch_list=[y_var]) + return y_np + + def run_Conv2D_dygraph(self): + paddle.seed(2023) + x_var = paddle.to_tensor(self.input) + x_var.stop_gradient = False + conv = nn.Conv2D( + self.num_channels, + self.num_filters, + self.filter_size, + padding=self.padding, + padding_mode=self.padding_mode, + stride=self.stride, + dilation=self.dilation, + groups=self.groups, + data_format=self.data_format, + ) + y_var = conv(x_var) + y_np = y_var.numpy() + return y_np + def _test_equivalence(self, place): - place = base.CPUPlace() result1 = self.base_layer(place) result2 = self.functional(place) with dg.guard(place): @@ -227,13 +272,22 @@ def _test_equivalence(self, place): np.testing.assert_array_almost_equal(result1, result2) np.testing.assert_array_almost_equal(result2, result3) + def _test_equivalence_in_pir(self, place): + with paddle.pir_utils.IrGuard(): + result1 = self.run_Conv2D_static(place) + with dg.guard(place): + result2 = self.run_Conv2D_dygraph() + np.testing.assert_array_almost_equal(result1, result2) + def runTest(self): place = base.CPUPlace() self._test_equivalence(place) + self._test_equivalence_in_pir(place) if base.core.is_compiled_with_cuda(): place = base.CUDAPlace(0) self._test_equivalence(place) + self._test_equivalence_in_pir(place) class Conv2DErrorTestCase(Conv2DTestCase): diff --git a/test/legacy_test/test_dist_base.py b/test/legacy_test/test_dist_base.py index db7d490e3a5afe..b4d8257503d401 100755 --- a/test/legacy_test/test_dist_base.py +++ b/test/legacy_test/test_dist_base.py @@ -1692,6 +1692,7 @@ def _get_required_envs(self, check_error_log=False, need_envs={}): "NCCL_P2P_DISABLE": "1", "NCCL_SHM_DISABLE": "1", "FLAGS_new_executor_static_build": "1", + "FLAGS_dynamic_static_unified_comm": "0", } if check_error_log: diff --git a/test/legacy_test/test_dist_hapi_model.py b/test/legacy_test/test_dist_hapi_model.py index 1e5ec1d341f71f..03a92d6f3cbc91 100644 --- a/test/legacy_test/test_dist_hapi_model.py +++ b/test/legacy_test/test_dist_hapi_model.py @@ -75,6 +75,7 @@ def start_local_trainers( "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "FLAGS_dynamic_static_unified_comm": "0", } current_env.update(proc_env) diff --git a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py index 32ee6fd8b39581..62a94832d1ae9e 100644 --- a/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py +++ b/test/legacy_test/test_distributed_fused_lamb_op_with_clip.py @@ -68,6 +68,7 @@ def run_test( os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm) os.environ['GRADIENT_MERGE_STEPS'] = str(gradient_merge_steps) os.environ['USE_MASTER_ACC_GRAD'] = str(1 if use_master_acc_grad else 0) + os.environ["FLAGS_dynamic_static_unified_comm"] = "0" os.environ.update(need_env) touch_file_env = 'SUCCESS_TOUCH_FILE' diff --git a/test/legacy_test/test_dot_op.py b/test/legacy_test/test_dot_op.py index d3035ac174f798..3b1a216add6da3 100644 --- a/test/legacy_test/test_dot_op.py +++ b/test/legacy_test/test_dot_op.py @@ -37,7 +37,7 @@ def setUp(self): self.attrs = {} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): if core.is_compiled_with_rocm(): @@ -45,10 +45,10 @@ def test_check_grad_normal(self): ['X', 'Y'], 'Out', user_defined_grads=[self.inputs['Y'], self.inputs['X']], - check_new_ir=True, + check_pir=True, ) else: - self.check_grad(['X', 'Y'], 'Out', check_new_ir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): if core.is_compiled_with_rocm(): @@ -57,12 +57,10 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), user_defined_grads=[self.inputs['X']], - check_new_ir=True, + check_pir=True, ) else: - self.check_grad( - ['Y'], 'Out', no_grad_set=set("X"), check_new_ir=True - ) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=True) def test_check_grad_ingore_y(self): if core.is_compiled_with_rocm(): @@ -71,12 +69,10 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), user_defined_grads=[self.inputs['Y']], - check_new_ir=True, + check_pir=True, ) else: - self.check_grad( - ['X'], 'Out', no_grad_set=set('Y'), check_new_ir=True - ) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True) def init_input_output(self): self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype) @@ -125,13 +121,13 @@ def init_input_output(self): self.out = np.sum(self.x * self.y, axis=1) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_new_ir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_new_ir=True) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=True) def test_check_grad_ingore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_new_ir=True) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True) class TestDotOpError(unittest.TestCase): @@ -234,20 +230,22 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=0.125) + self.check_output_with_place(place, atol=0.125, check_pir=True) def test_check_grad_normal(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad_with_place(place, ['X', 'Y'], 'Out') + self.check_grad_with_place( + place, ['X', 'Y'], 'Out', check_pir=True + ) def test_check_grad_ingore_x(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_grad_with_place( - place, ['Y'], 'Out', no_grad_set=set("X") + place, ['Y'], 'Out', no_grad_set=set("X"), check_pir=True ) def test_check_grad_ingore_y(self): @@ -255,7 +253,7 @@ def test_check_grad_ingore_y(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_grad_with_place( - place, ['X'], 'Out', no_grad_set=set("Y") + place, ['X'], 'Out', no_grad_set=set("Y"), check_pir=True ) def init_input_output(self): @@ -306,7 +304,7 @@ def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) if core.is_bfloat16_supported(place): - self.check_output_with_place(place, atol=0.5) + self.check_output_with_place(place, atol=0.5, check_pir=True) def test_check_grad_normal(self): if core.is_compiled_with_cuda(): @@ -317,6 +315,7 @@ def test_check_grad_normal(self): ['X', 'Y'], 'Out', user_defined_grads=[self.inputs['Y'], self.inputs['X']], + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -329,6 +328,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), user_defined_grads=[self.inputs['X']], + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -341,6 +341,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set("Y"), user_defined_grads=[self.inputs['Y']], + check_pir=True, ) def init_input_output(self): @@ -378,6 +379,7 @@ def test_check_grad_normal(self): self.y / self.y.shape[0], self.x / self.x.shape[0], ], + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -390,6 +392,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), user_defined_grads=[self.x / self.x.shape[0]], + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -402,6 +405,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set("Y"), user_defined_grads=[self.y / self.y.shape[0]], + check_pir=True, ) diff --git a/test/legacy_test/test_dropout_op.py b/test/legacy_test/test_dropout_op.py index f65e4d2b4b855b..433b9eeff7056d 100644 --- a/test/legacy_test/test_dropout_op.py +++ b/test/legacy_test/test_dropout_op.py @@ -26,6 +26,7 @@ from paddle.base.executor import scope_guard from paddle.decomposition import decompose from paddle.incubate.autograd import primapi +from paddle.pir_utils import test_with_pir_api def dropout_wapper( @@ -84,13 +85,11 @@ def setUp(self): self.enable_check_static_comp = False def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) def test_check_grad_normal(self): # Now in dy2st mode x_grad = [], so set check_prim=False - self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) class TestDropoutOp_ZeroDim(TestDropoutOp): @@ -129,13 +128,11 @@ def setUp(self): self.enable_check_static_comp = False def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) def test_check_grad_normal(self): # Now in dy2st mode x_grad = [], so set check_prim=False - self.check_grad(['X'], 'Out', check_prim=False, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=False, check_pir=True) class TestDropoutOp2(TestDropoutOp): @@ -198,9 +195,7 @@ def setUp(self): } def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) @skip_check_grad_ci(reason="For inference, check_grad is not required.") @@ -217,9 +212,7 @@ def setUp(self): } def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) class TestDropoutOp6(TestDropoutOp): @@ -281,9 +274,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X']} def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) @skip_check_grad_ci(reason="For inference, check_grad is not required.") @@ -302,9 +293,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X']} def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) class TestDropoutOpWithSeed(OpTest): @@ -331,9 +320,7 @@ def setUp(self): def test_check_output(self): # ir backward don't support of variable derivation of itself - self.check_output( - check_prim=True, check_prim_pir=False, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=False, check_pir=True) def test_check_grad_normal(self): # Now in dy2st mode x_grad = [], so set check_prim=False @@ -342,7 +329,7 @@ def test_check_grad_normal(self): 'Out', max_relative_error=0.05, check_prim=False, - check_new_ir=True, + check_pir=True, ) @@ -380,11 +367,11 @@ def test_check_output(self): atol=1e-3, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) @unittest.skipIf( @@ -419,9 +406,7 @@ def setUp(self): } def test_check_output(self): - self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True - ) + self.check_output(check_prim=True, check_prim_pir=True, check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -429,7 +414,7 @@ def test_check_grad_normal(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -539,9 +524,11 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): input = paddle.static.data( name="input", shape=[-1, -1], dtype="float32" ) @@ -590,7 +577,6 @@ def check_static_result(self, place): training=False, mode='downscale_in_infer', ) - res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True) res11 = paddle.nn.functional.dropout(x=input, p=0.0) res12 = paddle.nn.functional.dropout( x=input, @@ -600,13 +586,8 @@ def check_static_result(self, place): mode='upscale_in_train', ) - res13 = paddle.nn.functional.dropout( - x=input, p=0.7, axis=1, training=True, mode='upscale_in_train' - ) - in_np = np.ones([40, 40]).astype("float32") res_np = in_np - res_np2 = np.zeros_like(in_np) exe = base.Executor(place) res_list = [ @@ -624,26 +605,39 @@ def check_static_result(self, place): ] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + + @test_with_pir_api + def check_static_result2(self, place): + paddle.enable_static() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog): + input = paddle.static.data( + name="input", shape=[-1, -1], dtype="float32" + ) + res10 = paddle.nn.functional.dropout(x=input, p=1.0, training=True) + res13 = paddle.nn.functional.dropout( + x=input, p=0.7, axis=1, training=True, mode='upscale_in_train' + ) + in_np = np.ones([40, 40]).astype("float32") + res_np2 = np.zeros_like(in_np) + + exe = base.Executor(place) fetches2 = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, - fetch_list=[res10], + fetch_list=[res10, res13], ) np.testing.assert_allclose(fetches2[0], res_np2, rtol=1e-05) - fetches3 = exe.run( - base.default_main_program(), - feed={"input": in_np}, - fetch_list=[res13], - ) def test_static(self): for place in self.places: self.check_static_result(place=place) + self.check_static_result2(place=place) def test_dygraph(self): for place in self.places: @@ -785,6 +779,13 @@ def test_dtype(): self.assertRaises(TypeError, test_dtype) + @test_with_pir_api + def test_errors2(self): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + def test_pdtype(): # p should be int or float x2 = paddle.static.data( @@ -877,9 +878,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[2, 3, 4, 5], dtype="float32" ) @@ -897,7 +901,7 @@ def check_static_result(self, place): res_list = [res1, res2] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) @@ -927,9 +931,12 @@ def test_dygraph(self): class TestDropout2DFAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): def test_xdim(): # dimentions of x should be 4 @@ -970,6 +977,7 @@ def test_dygraph(self): result.numpy(), result_np, rtol=1e-05 ) + @test_with_pir_api def test_static_fp16_with_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -1002,9 +1010,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): paddle.enable_static() - with base.program_guard(base.Program(), base.Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[2, 3, 4, 5, 6], dtype="float32" ) @@ -1022,7 +1033,7 @@ def check_static_result(self, place): res_list = [res1, res2] for res in res_list: fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, fetch_list=[res], ) @@ -1052,9 +1063,12 @@ def test_dygraph(self): class TestDropout3DFAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): paddle.enable_static() - with program_guard(Program(), Program()): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): def test_xdim(): # dimentions of x should be 5 @@ -1103,8 +1117,12 @@ def setUp(self): if core.is_compiled_with_cuda(): self.places.append(base.CUDAPlace(0)) + @test_with_pir_api def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): input = paddle.static.data( name="input", shape=[40, 40], dtype="float32" ) @@ -1119,20 +1137,15 @@ def check_static_result(self, place): res_np3 = np.zeros_like(in_np) exe = base.Executor(place) - res_list = [res1, res2] - for res in res_list: - fetches = exe.run( - base.default_main_program(), - feed={"input": in_np}, - fetch_list=[res], - ) - np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + fetches = exe.run( - base.default_main_program(), + main_prog, feed={"input": in_np}, - fetch_list=[res3], + fetch_list=[res1, res2, res3], ) - np.testing.assert_allclose(fetches[0], res_np3, rtol=1e-05) + np.testing.assert_allclose(fetches[0], res_np, rtol=1e-05) + np.testing.assert_allclose(fetches[1], res_np, rtol=1e-05) + np.testing.assert_allclose(fetches[2], res_np3, rtol=1e-05) def test_static(self): for place in self.places: @@ -1171,6 +1184,13 @@ def test_Variable(): self.assertRaises(TypeError, test_Variable) + @test_with_pir_api + def test_errors2(self): + paddle.enable_static() + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + def test_dtype(): # the input dtype of dropout must be float32 or float64 xr = paddle.static.data( @@ -1219,6 +1239,7 @@ def test_dygraph(self): result.numpy(), result_np, rtol=1e-05 ) + @test_with_pir_api def test_static_fp16_gpu(self): if paddle.base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) @@ -1378,9 +1399,9 @@ def api_case(self, x): def run_static(self, x): paddle.seed(2022) - main_program = Program() paddle.enable_static() - with program_guard(main_program): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): input = paddle.static.data(shape=x.shape, name='x', dtype='float32') out = self.api_case(input) sgd = paddle.optimizer.SGD(learning_rate=0.1) @@ -2098,7 +2119,9 @@ def test_static_comp(self): ) core._set_prim_forward_enabled(True) - [output] = decompose(mp, [output]) # decompose forward + [output] = decompose( + mp, [output], whitelist={"pd_op.dropout"} + ) # decompose forward self.assertTrue( 'pd_op.dropout' not in [op.name() for op in mp.global_block().ops] diff --git a/test/legacy_test/test_eig_op.py b/test/legacy_test/test_eig_op.py index c5ba7262902c77..a7fa0665c645ac 100644 --- a/test/legacy_test/test_eig_op.py +++ b/test/legacy_test/test_eig_op.py @@ -183,7 +183,7 @@ def init_grad(self): def test_check_output(self): self.check_output_with_place_customized( - checker=self.checker, place=core.CPUPlace() + checker=self.checker, place=core.CPUPlace(), check_pir=True ) def test_check_grad(self): @@ -193,6 +193,7 @@ def test_check_grad(self): ['Eigenvalues', 'Eigenvectors'], user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_w, self.grad_v], + check_pir=True, ) diff --git a/test/legacy_test/test_eigvals_op.py b/test/legacy_test/test_eigvals_op.py index 6f3f126b2db3ed..379603234d5afe 100644 --- a/test/legacy_test/test_eigvals_op.py +++ b/test/legacy_test/test_eigvals_op.py @@ -37,6 +37,7 @@ class TestEigvalsOp(OpTest): def setUp(self): np.random.seed(0) paddle.enable_static() + self.python_api = paddle.linalg.eigvals self.op_type = "eigvals" self.set_dtype() self.set_input_dims() @@ -67,7 +68,7 @@ def set_input_data(self): def test_check_output(self): self.__class__.no_need_check_grad = True self.check_output_with_place_customized( - checker=self.verify_output, place=core.CPUPlace() + checker=self.verify_output, place=core.CPUPlace(), check_pir=True ) def verify_output(self, outs): diff --git a/test/legacy_test/test_elementwise_add_op.py b/test/legacy_test/test_elementwise_add_op.py index f5013d298e170a..d3039ca365d34c 100644 --- a/test/legacy_test/test_elementwise_add_op.py +++ b/test/legacy_test/test_elementwise_add_op.py @@ -56,7 +56,7 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode self.check_output( check_dygraph=self.check_dygraph(), - check_new_ir=self.check_dygraph(), + check_pir=self.check_dygraph(), ) def test_check_grad_normal(self): @@ -69,7 +69,7 @@ def test_check_grad_normal(self): check_dygraph=self.check_dygraph(), check_prim=self.check_prim, check_prim_pir=self.check_dygraph(), - check_new_ir=self.check_dygraph(), + check_pir=self.check_dygraph(), ) def test_check_grad_ingore_x(self): @@ -83,7 +83,7 @@ def test_check_grad_ingore_x(self): check_dygraph=self.check_dygraph(), check_prim=self.check_prim, check_prim_pir=self.check_dygraph(), - check_new_ir=self.check_dygraph(), + check_pir=self.check_dygraph(), ) def test_check_grad_ingore_y(self): @@ -97,7 +97,7 @@ def test_check_grad_ingore_y(self): check_dygraph=self.check_dygraph(), check_prim=self.check_prim, check_prim_pir=self.check_dygraph(), - check_new_ir=self.check_dygraph(), + check_pir=self.check_dygraph(), ) def init_input_output(self): @@ -153,7 +153,7 @@ def test_check_output(self): place, atol=1e-3, check_dygraph=self.check_dygraph(), - check_new_ir=self.check_dygraph(), + check_pir=self.check_dygraph(), ) def test_check_grad_normal(self): @@ -169,7 +169,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -181,7 +181,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -215,7 +215,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad_normal(self): place = core.CUDAPlace(0) @@ -225,7 +225,7 @@ def test_check_grad_normal(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -237,7 +237,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -249,7 +249,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def if_enable_cinn(self): @@ -744,16 +744,16 @@ def init_input_output(self): self.out = self.x + self.y def test_check_output(self): - self.check_output(check_new_ir=False) + self.check_output(check_pir=False) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_new_ir=False) + self.check_grad(['X', 'Y'], 'Out', check_pir=False) def test_check_grad_ingore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_new_ir=False) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=False) def test_check_grad_ingore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_new_ir=False) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=False) class TestRealComplexElementwiseAddOp(TestComplexElementwiseAddOp): @@ -772,7 +772,11 @@ def test_static_add(self): b = paddle.full([4, 5, 6], True, dtype='bool') c = a + b self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) - paddle.enable_static() + with paddle.pir_utils.IrGuard(): + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.DataType.FLOAT32) def test_dygraph_add(self): paddle.disable_static() diff --git a/test/legacy_test/test_elementwise_div_op.py b/test/legacy_test/test_elementwise_div_op.py index bb1676bb00afbe..c17d2b8946a7c5 100644 --- a/test/legacy_test/test_elementwise_div_op.py +++ b/test/legacy_test/test_elementwise_div_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def broadcast_wrapper(shape=[1, 10, 12, 1]): @@ -98,9 +99,9 @@ def compute_gradient_y(self, grad_out, out, y): def test_check_output(self): if self.place is None: - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) else: - self.check_output_with_place(self.place, check_new_ir=True) + self.check_output_with_place(self.place, check_pir=True) def test_check_gradient(self): check_list = [] @@ -128,11 +129,11 @@ def test_check_gradient(self): 'check_prim_pir': self.check_prim_pir, } if self.place is None: - self.check_grad(*check_args, **check_kwargs, check_new_ir=True) + self.check_grad(*check_args, **check_kwargs, check_pir=True) else: check_args.insert(0, self.place) self.check_grad_with_place( - *check_args, **check_kwargs, check_new_ir=True + *check_args, **check_kwargs, check_pir=True ) @@ -221,11 +222,11 @@ def test_check_gradient(self): 'check_prim_pir': self.check_prim_pir, } if self.place is None: - self.check_grad(*check_args, **check_kwargs, check_new_ir=True) + self.check_grad(*check_args, **check_kwargs, check_pir=True) else: check_args.insert(0, self.place) self.check_grad_with_place( - *check_args, **check_kwargs, check_new_ir=True + *check_args, **check_kwargs, check_pir=True ) def if_check_prim(self): @@ -279,11 +280,11 @@ def test_check_gradient(self): 'check_dygraph': self.check_dygraph, } if self.place is None: - self.check_grad(*check_args, **check_kwargs, check_new_ir=True) + self.check_grad(*check_args, **check_kwargs, check_pir=True) else: check_args.insert(0, self.place) self.check_grad_with_place( - *check_args, **check_kwargs, check_new_ir=True + *check_args, **check_kwargs, check_pir=True ) @@ -454,15 +455,13 @@ def test_check_gradient(self): 'max_relative_error': max_relative_error, } if self.place is None: - self.check_grad( - *check_args, **check_kwargs, check_new_ir=True - ) + self.check_grad(*check_args, **check_kwargs, check_pir=True) else: check_args.insert(0, self.place) self.check_grad_with_place( *check_args, **check_kwargs, - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True ) @@ -490,6 +489,7 @@ def test_check_gradient(self): class TestElementwiseDivBroadcast(unittest.TestCase): + @test_with_pir_api def test_shape_with_batch_sizes(self): paddle.enable_static() with base.program_guard(base.Program()): @@ -514,6 +514,17 @@ def test_name(self): y_1 = paddle.divide(x, y, name='div_res') self.assertEqual(('div_res' in y_1.name), True) + + with paddle.pir_utils.IrGuard(), base.program_guard(base.Program()): + x = paddle.static.data(name="x", shape=[2, 3], dtype="float32") + y = paddle.static.data(name='y', shape=[2, 3], dtype='float32') + + y_1 = paddle.divide(x, y, name='div_res') + + def name_call(): + self.assertEqual(('div_res' in y_1.name), True) + + self.assertRaises(ValueError, name_call) paddle.disable_static() def test_dygraph(self): @@ -556,7 +567,7 @@ def init_input_output(self): self.out = self.x / self.y def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -564,7 +575,7 @@ def test_check_grad_normal(self): 'Out', numeric_grad_delta=1e-5, max_relative_error=1e-6, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -574,7 +585,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), numeric_grad_delta=1e-5, max_relative_error=1e-6, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -584,7 +595,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), numeric_grad_delta=1e-5, max_relative_error=1e-6, - check_new_ir=True, + check_pir=True, ) diff --git a/test/legacy_test/test_elementwise_mod_op.py b/test/legacy_test/test_elementwise_mod_op.py index bb9348b358ebdf..ba6a75c9e6ac87 100644 --- a/test/legacy_test/test_elementwise_mod_op.py +++ b/test/legacy_test/test_elementwise_mod_op.py @@ -45,9 +45,9 @@ def setUp(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype) @@ -102,9 +102,9 @@ def init_input_output(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) @unittest.skipIf( @@ -121,9 +121,9 @@ def init_input_output(self): def test_check_output(self): if self.attrs['axis'] == -1: - self.check_output() + self.check_output(check_pir=True) else: - self.check_output() + self.check_output(check_pir=True) class TestElementwiseModFP16Op_ZeroDim1(TestElementwiseModFP16Op): @@ -181,7 +181,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_pir=True) def init_dtype(self): self.dtype = np.uint16 diff --git a/test/legacy_test/test_elementwise_mul_op.py b/test/legacy_test/test_elementwise_mul_op.py index b5a4689c2d40de..0787bf4f5104ae 100644 --- a/test/legacy_test/test_elementwise_mul_op.py +++ b/test/legacy_test/test_elementwise_mul_op.py @@ -49,7 +49,7 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode self.check_output( check_dygraph=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def test_check_grad_normal(self): @@ -60,7 +60,7 @@ def test_check_grad_normal(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def test_check_grad_ingore_x(self): @@ -72,7 +72,7 @@ def test_check_grad_ingore_x(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def test_check_grad_ingore_y(self): @@ -84,7 +84,7 @@ def test_check_grad_ingore_y(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def init_input_output(self): @@ -132,13 +132,13 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_new_ir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_new_ir=True) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=True) def test_check_grad_ingore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_new_ir=True) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True) class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp): @@ -189,7 +189,7 @@ def setUp(self): self.if_enable_cinn() def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -197,7 +197,7 @@ def test_check_grad_normal(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -207,7 +207,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -217,7 +217,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def if_enable_cinn(self): @@ -274,7 +274,7 @@ def setUp(self): def test_check_output(self): self.check_output( check_dygraph=self.check_dygraph, - check_new_ir=self.check_dygraph, + check_pir=self.check_dygraph, ) def test_check_grad_normal(self): @@ -283,7 +283,7 @@ def test_check_grad_normal(self): 'Out', check_dygraph=self.check_dygraph, check_prim=self.check_prim, - check_new_ir=self.check_dygraph, + check_pir=self.check_dygraph, ) def test_check_grad_ingore_x(self): @@ -293,7 +293,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_dygraph=self.check_dygraph, check_prim=self.check_prim, - check_new_ir=self.check_dygraph, + check_pir=self.check_dygraph, ) def test_check_grad_ingore_y(self): @@ -303,7 +303,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_dygraph=self.check_dygraph, check_prim=self.check_prim, - check_new_ir=self.check_dygraph, + check_pir=self.check_dygraph, ) def init_input_attr_output(self): @@ -432,7 +432,7 @@ def test_check_grad_normal(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def test_check_grad_ingore_x(self): @@ -444,7 +444,7 @@ def test_check_grad_ingore_x(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) def test_check_grad_ingore_y(self): @@ -456,7 +456,7 @@ def test_check_grad_ingore_y(self): check_dygraph=(not self.use_mkldnn), check_prim=True, check_prim_pir=(not self.use_mkldnn), - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), ) @@ -535,16 +535,16 @@ def init_input_output(self): self.out = self.x * self.y def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', check_new_ir=True) + self.check_grad(['X', 'Y'], 'Out', check_pir=True) def test_check_grad_ingore_x(self): - self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_new_ir=True) + self.check_grad(['Y'], 'Out', no_grad_set=set("X"), check_pir=True) def test_check_grad_ingore_y(self): - self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_new_ir=True) + self.check_grad(['X'], 'Out', no_grad_set=set('Y'), check_pir=True) class TestRealComplexElementwiseMulOp(TestComplexElementwiseMulOp): diff --git a/test/legacy_test/test_elementwise_pow_op.py b/test/legacy_test/test_elementwise_pow_op.py index c83676f686d7a0..82d4f889b28a15 100644 --- a/test/legacy_test/test_elementwise_pow_op.py +++ b/test/legacy_test/test_elementwise_pow_op.py @@ -44,7 +44,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): if hasattr(self, 'attrs'): @@ -57,7 +57,7 @@ def test_check_grad_normal(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -204,7 +204,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestElementwisePowGradOpInt(unittest.TestCase): @@ -260,7 +260,7 @@ def test_check_output(self): if hasattr(self, 'attrs'): self.check_output(check_dygraph=False) else: - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -271,7 +271,7 @@ def test_check_grad(self): ), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -297,7 +297,7 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(out)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad(['X', 'Y'], 'Out') diff --git a/test/legacy_test/test_elementwise_sub_op.py b/test/legacy_test/test_elementwise_sub_op.py index 9058c6e79e2b72..29185c1844bf4d 100644 --- a/test/legacy_test/test_elementwise_sub_op.py +++ b/test/legacy_test/test_elementwise_sub_op.py @@ -23,6 +23,7 @@ from paddle import base from paddle.base import core from paddle.base.layer_helper import LayerHelper +from paddle.pir_utils import test_with_pir_api class TestElementwiseOp(OpTest): @@ -44,7 +45,7 @@ def init_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -52,7 +53,7 @@ def test_check_grad_normal(self): 'Out', check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_x(self): @@ -63,7 +64,7 @@ def test_check_grad_ingore_x(self): no_grad_set=set("X"), check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -74,7 +75,7 @@ def test_check_grad_ingore_y(self): no_grad_set=set('Y'), check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=True, + check_pir=True, ) def if_check_prim(self): @@ -134,7 +135,7 @@ def test_check_grad_ingore_x(self): max_relative_error=0.1, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_ingore_y(self): @@ -147,7 +148,7 @@ def test_check_grad_ingore_y(self): max_relative_error=0.1, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -392,12 +393,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_dygraph=False, check_new_ir=False) + self.check_output(check_dygraph=False, check_pir=False) def test_check_grad_normal(self): - self.check_grad( - ['X', 'Y'], 'Out', check_dygraph=False, check_new_ir=False - ) + self.check_grad(['X', 'Y'], 'Out', check_dygraph=False, check_pir=False) def test_check_grad_ingore_x(self): self.check_grad( @@ -406,7 +405,7 @@ def test_check_grad_ingore_x(self): max_relative_error=0.005, no_grad_set=set("X"), check_dygraph=False, - check_new_ir=False, + check_pir=False, ) def test_check_grad_ingore_y(self): @@ -416,7 +415,7 @@ def test_check_grad_ingore_y(self): max_relative_error=0.005, no_grad_set=set('Y'), check_dygraph=False, - check_new_ir=False, + check_pir=False, ) @@ -452,13 +451,13 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_dygraph=False, check_new_ir=False + place, check_dygraph=False, check_pir=False ) def test_check_grad_normal(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X', 'Y'], 'Out', check_dygraph=False, check_new_ir=False + place, ['X', 'Y'], 'Out', check_dygraph=False, check_pir=False ) def test_check_grad_ingore_x(self): @@ -469,7 +468,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), check_dygraph=False, - check_new_ir=False, + check_pir=False, ) def test_check_grad_ingore_y(self): @@ -480,7 +479,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), check_dygraph=False, - check_new_ir=False, + check_pir=False, ) @@ -846,11 +845,11 @@ def init_input_output(self): self.out = self.x - self.y def test_check_output(self): - self.check_output(check_new_ir=False) + self.check_output(check_pir=False) def test_check_grad_normal(self): self.check_grad( - ['X', 'Y'], 'Out', check_prim=self.check_prim, check_new_ir=False + ['X', 'Y'], 'Out', check_prim=self.check_prim, check_pir=False ) def test_check_grad_ingore_x(self): @@ -859,7 +858,7 @@ def test_check_grad_ingore_x(self): 'Out', no_grad_set=set("X"), check_prim=self.check_prim, - check_new_ir=False, + check_pir=False, ) def test_check_grad_ingore_y(self): @@ -868,7 +867,7 @@ def test_check_grad_ingore_y(self): 'Out', no_grad_set=set('Y'), check_prim=self.check_prim, - check_new_ir=False, + check_pir=False, ) def if_enable_cinn(self): @@ -905,8 +904,9 @@ def test_name(self): y_1 = self._executed_api(x, y, name='subtract_res') self.assertEqual(('subtract_res' in y_1.name), True) + @test_with_pir_api def test_declarative(self): - with base.program_guard(base.Program()): + with paddle.static.program_guard(paddle.static.Program()): def gen_data(): return { @@ -919,7 +919,10 @@ def gen_data(): z = self._executed_api(x, y) place = base.CPUPlace() exe = base.Executor(place) - z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) + if paddle.framework.in_pir_mode(): + z_value = exe.run(feed=gen_data(), fetch_list=[z]) + else: + z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) z_expected = np.array([1.0, -2.0, 2.0]) self.assertEqual((z_value == z_expected).all(), True) diff --git a/test/legacy_test/test_empty_op.py b/test/legacy_test/test_empty_op.py index 44e1f2fe30fb62..a49489417878ac 100644 --- a/test/legacy_test/test_empty_op.py +++ b/test/legacy_test/test_empty_op.py @@ -31,7 +31,7 @@ def setUp(self): self.init_config() def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): data_type = outs[0].dtype @@ -121,7 +121,7 @@ def init_config(self): self.outputs = {'Out': np.zeros(self.shape).astype(dtype)} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): data_type = outs[0].dtype @@ -172,7 +172,7 @@ def init_config(self): self.outputs = {'Out': np.zeros(self.shape).astype(dtype)} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): data_type = outs[0].dtype @@ -312,7 +312,7 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(output)} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): max_value = np.nanmax(outs[0]) diff --git a/test/legacy_test/test_erf_op.py b/test/legacy_test/test_erf_op.py index 24f32175151d65..d66cdc3ce11793 100644 --- a/test/legacy_test/test_erf_op.py +++ b/test/legacy_test/test_erf_op.py @@ -44,10 +44,18 @@ def _init_dtype(self): return "float64" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) + + def test_check_grad_prim_pir(self): + # Todo(CZ): float64 loss greater than 1e-8 + if self.dtype == "float64": + self.dtype = "float32" + self.rev_comp_atol = 1e-7 + self.rev_comp_rtol = 1e-7 + self.check_grad(['X'], 'Out', check_prim_pir=True) class TestErfOp_ZeroDim(TestErfOp): @@ -93,10 +101,16 @@ def setUp(self): self.outputs = {'Out': y_ref} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) @unittest.skipIf( @@ -121,12 +135,17 @@ def setUp(self): def test_check_output(self): place = paddle.base.core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = paddle.base.core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_expand_v2_op.py b/test/legacy_test/test_expand_v2_op.py index f7ba37fb60cbb4..988043d472e252 100644 --- a/test/legacy_test/test_expand_v2_op.py +++ b/test/legacy_test/test_expand_v2_op.py @@ -47,10 +47,16 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestExpandV2OpRank1_ZeroDim1(TestExpandV2OpRank1): @@ -130,10 +136,10 @@ def init_data(self): self.infer_expand_shape = [-1] def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_cinn=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_cinn=True, check_pir=True) class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr): @@ -167,10 +173,10 @@ def init_data(self): self.expand_shape = [2, 100] def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_cinn=True, check_new_ir=True) + self.check_grad(['X'], 'Out', check_cinn=True, check_pir=True) # Situation 4: input x is Integer @@ -188,7 +194,7 @@ def setUp(self): self.outputs = {'Out': output} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) # Situation 5: input x is Bool @@ -204,7 +210,7 @@ def setUp(self): self.outputs = {'Out': output} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) # Situation 6: input x is Integer @@ -222,7 +228,7 @@ def setUp(self): self.outputs = {'Out': output} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) # Situation 7: input x is Float16 @@ -244,7 +250,13 @@ def test_check_output(self): self.check_output(check_cinn=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) # Situation 8: input x is BF16 @@ -268,12 +280,17 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_cinn=True, check_new_ir=True) + self.check_output_with_place(place, check_cinn=True, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -438,7 +455,7 @@ def test_check_output(self): self.check_output(check_prim=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_prim_pir=True) class TestExpandV2OpCompRank2_DimExpanding(TestExpandV2CompOpRank1): diff --git a/test/legacy_test/test_exponential_op.py b/test/legacy_test/test_exponential_op.py index de92243084ffbe..1df9276590a0f2 100644 --- a/test/legacy_test/test_exponential_op.py +++ b/test/legacy_test/test_exponential_op.py @@ -37,7 +37,7 @@ def config(self): self.dtype = "float64" def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist1, _ = np.histogram(outs[0], range=(0, 5)) @@ -360,7 +360,7 @@ def config(self): self.dtype = np.float16 def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist1, _ = np.histogram(outs[0], range=(0, 5)) @@ -411,7 +411,7 @@ def config(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place_customized( - checker=self.verify_output, place=place + checker=self.verify_output, place=place, check_pir=True ) def verify_output(self, outs): diff --git a/test/legacy_test/test_fill_any_like_op.py b/test/legacy_test/test_fill_any_like_op.py index ebcbd575384212..a60ab183e36cd8 100644 --- a/test/legacy_test/test_fill_any_like_op.py +++ b/test/legacy_test/test_fill_any_like_op.py @@ -58,7 +58,7 @@ def init(self): pass def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output(check_prim=True, check_pir=True) def if_enable_cinn(self): pass @@ -96,7 +96,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_prim=True, check_new_ir=True) + self.check_output_with_place(place, check_prim=True, check_pir=True) def if_enable_cinn(self): pass diff --git a/test/legacy_test/test_fill_constant_op.py b/test/legacy_test/test_fill_constant_op.py index 9f354b5d992767..7ea153d627cbdb 100644 --- a/test/legacy_test/test_fill_constant_op.py +++ b/test/legacy_test/test_fill_constant_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api def fill_wrapper(shape, value=0.0): @@ -44,7 +45,7 @@ def setUp(self): self.outputs = {'Out': np.full(self.shape, self.value)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def init_dtype(self): self.dtype = np.float64 @@ -115,7 +116,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) class TestFillConstantOpWithSelectedRows(unittest.TestCase): @@ -168,7 +169,7 @@ def init_data(self): self.value = 3.8 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestFillConstantOp2_ShapeTensorList(OpTest): @@ -192,7 +193,7 @@ def init_data(self): self.infer_shape = [-1, -1] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestFillConstantOp3_ShapeTensorList(TestFillConstantOp1_ShapeTensorList): @@ -226,7 +227,7 @@ def init_data(self): self.value = 3.8 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) # Situation 4: value is a tensor @@ -250,7 +251,7 @@ def init_data(self): self.dtype = np.float32 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) # Situation 5: value is a tensor @@ -274,12 +275,14 @@ def init_data(self): self.dtype = np.int32 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) # Test python API class TestFillConstantAPI(unittest.TestCase): + @test_with_pir_api def test_api(self): + paddle.enable_static() positive_2_int32 = paddle.tensor.fill_constant([1], "int32", 2) positive_2_int64 = paddle.tensor.fill_constant([1], "int64", 2) @@ -330,7 +333,7 @@ def test_api(self): exe = base.Executor(place=base.CPUPlace()) res_1, res_2, res_3, res_4, res_5, res_6, res_7, res_8 = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={ "shape_tensor_int32": np.array([1, 2]).astype("int32"), "shape_tensor_int64": np.array([1, 2]).astype("int64"), @@ -487,6 +490,58 @@ def test_shape_tensor_list_dtype(): self.assertRaises(TypeError, test_shape_tensor_list_dtype) + with paddle.pir_utils.IrGuard(), program_guard(Program()): + x1 = paddle.static.data(name='x1', shape=[-1, 1], dtype="int16") + self.assertRaises( + TypeError, + paddle.tensor.fill_constant, + shape=[1], + value=5, + dtype='uint4', + ) + + self.assertRaises( + ValueError, + paddle.tensor.fill_constant, + shape=[1.1], + value=5, + dtype='float32', + out=x1, + ) + + x3 = np.random.randn(100, 100).astype('int32') + self.assertRaises( + ValueError, + paddle.tensor.fill_constant, + shape=[100, 100], + value=5, + dtype='float64', + out=x3, + ) + + def test_pir_errors(self): + def test_shape_type(): + # The shape dtype of fill_constant_op must be int32 or int64. + # test_shape_tensor_dtype: + with paddle.pir_utils.IrGuard(): + new_ir_program = paddle.static.Program() + with paddle.static.program_guard(new_ir_program): + shape = paddle.static.data( + name="shape_tensor", shape=[2], dtype="int32" + ) + out = paddle.tensor.fill_constant( + shape=shape, dtype="float32", value=1 + ) + exe = base.Executor(place=base.CPUPlace()) + exe.run( + feed={ + "shape_tensor": np.array([1, 2]).astype("float32") + }, + fetch_list=[out], + ) + + self.assertRaises(ValueError, test_shape_type) + class TestFillConstantOp_ValueTensorBf16(OpTest): def setUp(self): @@ -513,7 +568,7 @@ def init_data(self): def test_check_output(self): # no dynamic graph test for mkldnn self.check_output_with_place( - core.CPUPlace(), check_dygraph=False, check_new_ir=False + core.CPUPlace(), check_dygraph=False, check_pir=False ) diff --git a/test/legacy_test/test_flatten_contiguous_range_op.py b/test/legacy_test/test_flatten_contiguous_range_op.py index 71e39e92b8c5a0..83354d87b705bf 100644 --- a/test/legacy_test/test_flatten_contiguous_range_op.py +++ b/test/legacy_test/test_flatten_contiguous_range_op.py @@ -49,11 +49,11 @@ def test_check_output(self): core.CUDAPlace(0), no_check_set=["XShape"], check_prim=True, - check_new_ir=True, + check_pir=True, ) else: self.check_output( - no_check_set=["XShape"], check_prim=True, check_new_ir=True + no_check_set=["XShape"], check_prim=True, check_pir=True ) def test_check_grad(self): @@ -63,10 +63,10 @@ def test_check_grad(self): ["X"], "Out", check_prim=True, - check_new_ir=True, + check_pir=True, ) else: - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad(["X"], "Out", check_prim=True, check_pir=True) def init_test_case(self): self.in_shape = (3, 2, 5, 4) diff --git a/test/legacy_test/test_full_like_op.py b/test/legacy_test/test_full_like_op.py index 137e536126bb46..5cbcc3f5c78aa1 100644 --- a/test/legacy_test/test_full_like_op.py +++ b/test/legacy_test/test_full_like_op.py @@ -148,7 +148,7 @@ def init_data(self): self.dtype = np.float32 def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=True) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def if_enable_cinn(self): pass diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index 5be92f6f9b7056..d201b9d76e8d3a 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -31,7 +31,7 @@ def deal_qkv(init_q, init_k, init_v): def mult_qkv(value, cos_tensor, sin_tensor): rotate_half_q = paddle.reshape( - paddle.stack([value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1), + paddle.stack([-value[:, :, :, 1::2], value[:, :, :, 0::2]], axis=-1), paddle.shape(value), ) query = paddle.add( @@ -59,7 +59,7 @@ def mult_qkv_rotate_half(value, cos_tensor, sin_tensor): return query -def get_sin_cos_tensor(seq_len, head_dim, sign): +def get_sin_cos_tensor(seq_len, head_dim, sign=1): pos_seq = paddle.arange(0, seq_len, 1, dtype="float32") indices = paddle.arange(0, head_dim, 2, dtype="float32") @@ -93,15 +93,18 @@ def get_sin_cos_tensor(seq_len, head_dim, sign): def paddle_fused_rotary_position_embedding( - init_q, init_k, init_v, position_ids=None, use_neox_rotary_style=True + init_q, + init_k, + init_v, + sin_tensor=None, + cos_tensor=None, + position_ids=None, + use_neox_rotary_style=True, ): # permute q, k, v from [batch_size, seq_len, num_heads, head_dim] # to [batch_size, num_heads, seq_len, head_dim] q, k, v = deal_qkv(init_q, init_k, init_v) - sign = -1 if use_neox_rotary_style else 1 - sin_tensor, cos_tensor = get_sin_cos_tensor(q.shape[2], q.shape[3], sign) - if position_ids is not None: sin_tensor = sin_tensor.squeeze(axis=[0, 2]) # [seq_len, dim] cos_tensor = cos_tensor.squeeze(axis=[0, 2]) # [seq_len, dim] @@ -146,60 +149,45 @@ def get_paddle_tensor(self): tmp.stop_gradient = False return tmp + def get_inputs(self, seed, with_sin_cos): + paddle.disable_static() + paddle.seed(seed) + tensor_q = self.get_paddle_tensor() + tensor_k = self.get_paddle_tensor() + tensor_v = self.get_paddle_tensor() + + tensor_sin, tensor_cos = ( + get_sin_cos_tensor(tensor_q.shape[1], tensor_q.shape[3], 1) + if with_sin_cos + else (None, None) + ) + return tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos + def get_forward_backward( self, rope_function, seed, - flag=False, + with_sin_cos=True, use_neox_rotary_style=True, position_ids=None, ): paddle.disable_static() - paddle.seed(seed) fw = [] bw = [] - tensor_q = self.get_paddle_tensor() - tensor_k = self.get_paddle_tensor() - tensor_v = self.get_paddle_tensor() - if use_neox_rotary_style: - if flag: - tensor_sin, tensor_cos = get_sin_cos_tensor( - tensor_q.shape[1], tensor_q.shape[3], 1 - ) - out_q, out_k, out_v = rope_function( - tensor_q, - tensor_k, - tensor_v, - tensor_sin, - tensor_cos, - position_ids=position_ids, - ) - else: - out_q, out_k, out_v = rope_function( - tensor_q, tensor_k, tensor_v, position_ids=position_ids - ) - else: - if flag: - tensor_sin, tensor_cos = get_sin_cos_tensor( - tensor_q.shape[1], tensor_q.shape[3], 1 - ) - out_q, out_k, out_v = rope_function( - tensor_q, - tensor_k, - tensor_v, - tensor_sin, - tensor_cos, - position_ids=position_ids, - use_neox_rotary_style=False, - ) - else: - out_q, out_k, out_v = rope_function( - tensor_q, - tensor_k, - tensor_v, - position_ids=position_ids, - use_neox_rotary_style=False, - ) + + tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( + seed, with_sin_cos + ) + + out_q, out_k, out_v = rope_function( + tensor_q, + tensor_k, + tensor_v, + tensor_sin, + tensor_cos, + position_ids=position_ids, + use_neox_rotary_style=use_neox_rotary_style, + ) fw.append(out_q) fw.append(out_k) @@ -208,6 +196,7 @@ def get_forward_backward( out_gq = paddle.randn(out_q.shape, self.dtype) out_gk = paddle.randn(out_q.shape, self.dtype) out_gv = paddle.randn(out_q.shape, self.dtype) + paddle.autograd.backward( [out_q, out_k, out_v], [out_gq, out_gk, out_gv], True ) @@ -234,10 +223,14 @@ def test_fused_rope(self): def test_fused_rope_with_sin_cos(self): p_fw, p_bw = self.get_forward_backward( - paddle_fused_rotary_position_embedding, seed=self.seed + paddle_fused_rotary_position_embedding, + seed=self.seed, + with_sin_cos=True, ) f_fw, f_bw = self.get_forward_backward( - fused_rotary_position_embedding, seed=self.seed, flag=True + fused_rotary_position_embedding, + seed=self.seed, + with_sin_cos=True, ) for i in range(len(p_fw)): np.testing.assert_allclose( @@ -278,7 +271,6 @@ def test_fused_rope_position_ids(self): f_fw, f_bw = self.get_forward_backward( fused_rotary_position_embedding, seed=self.seed, - flag=True, position_ids=position_ids, ) for i in range(len(p_fw)): @@ -289,13 +281,59 @@ def test_fused_rope_position_ids(self): p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 ) - def test_error(self): + def test_static(self): + tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( + self.seed, True + ) + p_fw, p_bw = self.get_forward_backward( + paddle_fused_rotary_position_embedding, + seed=self.seed, + use_neox_rotary_style=False, + ) + paddle.enable_static() - with self.assertRaises(RuntimeError): - static_q = paddle.static.data( - name="q", shape=self.shape, dtype=self.dtype - ) - fused_rotary_position_embedding(static_q, static_q, static_q) + + q = paddle.static.data(name="q", shape=self.shape, dtype=self.dtype) + k = paddle.static.data(name="k", shape=self.shape, dtype=self.dtype) + v = paddle.static.data(name="v", shape=self.shape, dtype=self.dtype) + sin = paddle.static.data( + name="sin", + shape=(1, tensor_q.shape[1], 1, tensor_q.shape[3]), + dtype=self.dtype, + ) + cos = paddle.static.data( + name="cos", + shape=(1, tensor_q.shape[1], 1, tensor_q.shape[3]), + dtype=self.dtype, + ) + + out_q, out_k, out_v = fused_rotary_position_embedding( + q, + k, + v, + sin, + cos, + position_ids=None, + use_neox_rotary_style=False, + ) + + exe = paddle.static.Executor() + + feed = { + 'q': tensor_q.numpy(), + 'k': tensor_k.numpy(), + 'v': tensor_v.numpy(), + 'sin': tensor_sin.numpy(), + 'cos': tensor_cos.numpy(), + } + outs = exe.run( + paddle.static.default_main_program(), + feed=feed, + fetch_list=[out_q, out_k, out_v], + ) + + for i in range(3): + np.testing.assert_allclose(p_fw[i].numpy(), outs[i], rtol=1e-05) paddle.disable_static() diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index a10faff2ac1f35..3a27faf99cb6b8 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -53,10 +53,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex): @@ -75,12 +81,17 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -114,10 +125,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithIndex1_ZeroDim(TestGatherNdOpWithIndex1): @@ -163,12 +180,17 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -202,10 +224,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex): @@ -224,7 +252,7 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -233,8 +261,9 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -273,15 +302,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, numeric_grad_delta=0.05, + check_prim_pir=True, ) @@ -301,7 +331,7 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -310,8 +340,9 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -342,10 +373,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX): @@ -364,7 +401,7 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -373,8 +410,9 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, numeric_grad_delta=0.5, + check_prim_pir=True, ) @@ -407,10 +445,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame): @@ -429,12 +473,17 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -468,10 +517,16 @@ def config_dtype(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff): @@ -490,12 +545,17 @@ def config_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_gaussian_random_op.py b/test/legacy_test/test_gaussian_random_op.py index 8f03e0f547e8de..2a0f30a84e03c9 100644 --- a/test/legacy_test/test_gaussian_random_op.py +++ b/test/legacy_test/test_gaussian_random_op.py @@ -46,7 +46,7 @@ def set_attrs(self): self.std = 2.0 def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): self.assertEqual(outs[0].shape, (123, 92)) @@ -88,7 +88,7 @@ def set_attrs(self): def test_check_output(self): self.check_output_with_place_customized( - self.verify_output, place=core.CUDAPlace(0), check_new_ir=True + self.verify_output, place=core.CUDAPlace(0), check_pir=True ) def verify_output(self, outs): @@ -141,7 +141,7 @@ def set_attrs(self): def test_check_output(self): self.check_output_with_place_customized( - self.verify_output, place=core.CUDAPlace(0), check_new_ir=True + self.verify_output, place=core.CUDAPlace(0), check_pir=True ) def verify_output(self, outs): @@ -196,7 +196,7 @@ def init_data(self): self.seed = 10 def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) class TestGaussianRandomOp2_ShapeTensorList( diff --git a/test/legacy_test/test_gumbel_softmax_op.py b/test/legacy_test/test_gumbel_softmax_op.py index e3fbf15a299d8c..97751840e687e4 100644 --- a/test/legacy_test/test_gumbel_softmax_op.py +++ b/test/legacy_test/test_gumbel_softmax_op.py @@ -46,10 +46,10 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) class TestGumbelSoftmax_ZeroDim(OpTest): @@ -68,10 +68,10 @@ def setUp(self): self.attrs = {"hard": True, "axis": -1} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) class TestGumbelSoftmaxOp2(TestGumbelSoftmaxOp): @@ -176,7 +176,7 @@ def setUp(self): self.outputs = {'Out': out} def test_check_output(self): - self.check_output_customized(self.accumulate_output, check_new_ir=True) + self.check_output_customized(self.accumulate_output, check_pir=True) # Experiment should result in batch num . self.assertEqual(self.counts.sum(), self.shape[0]) @@ -192,7 +192,7 @@ def test_check_output(self): self.assertLess(np.max(np.abs(z)).item(), 2.58) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) class TestGumbelSoftmaxOpGrad(unittest.TestCase): diff --git a/test/legacy_test/test_hypot.py b/test/legacy_test/test_hypot.py new file mode 100644 index 00000000000000..66a049038eb5ae --- /dev/null +++ b/test/legacy_test/test_hypot.py @@ -0,0 +1,104 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base +from paddle.base import core + +paddle.enable_static() + + +class TestHypotAPI(unittest.TestCase): + def setUp(self): + self.x_shape = [10, 10] + self.y_shape = [10, 1] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + def test_static_graph(self): + paddle.enable_static() + startup_program = base.Program() + train_program = base.Program() + with base.program_guard(startup_program, train_program): + x = paddle.static.data( + name='input1', dtype='float32', shape=self.x_shape + ) + y = paddle.static.data( + name='input2', dtype='float32', shape=self.y_shape + ) + out = paddle.hypot(x, y) + + place = ( + base.CUDAPlace(0) + if core.is_compiled_with_cuda() + else base.CPUPlace() + ) + exe = base.Executor(place) + res = exe.run( + base.default_main_program(), + feed={'input1': self.x_np, 'input2': self.y_np}, + fetch_list=[out], + ) + np_out = np.hypot(self.x_np, self.y_np) + np.testing.assert_allclose(res[0], np_out, atol=1e-5, rtol=1e-5) + paddle.disable_static() + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + result = paddle.hypot(x, y) + np.testing.assert_allclose( + np.hypot(self.x_np, self.y_np), result.numpy(), rtol=1e-05 + ) + + paddle.enable_static() + + def test_error(self): + x = paddle.to_tensor(self.x_np) + y = 3.8 + self.assertRaises(TypeError, paddle.hypot, x, y) + self.assertRaises(TypeError, paddle.hypot, y, x) + + +class TestHypotAPIBroadCast(TestHypotAPI): + def setUp(self): + self.x_np = np.arange(6).astype(np.float32) + self.y_np = np.array([20]).astype(np.float32) + self.x_shape = [6] + self.y_shape = [1] + + +class TestHypotAPI3(TestHypotAPI): + def setUp(self): + self.x_shape = [] + self.y_shape = [] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + +class TestHypotAPI4(TestHypotAPI): + def setUp(self): + self.x_shape = [1] + self.y_shape = [1] + self.x_np = np.random.uniform(-10, 10, self.x_shape).astype(np.float32) + self.y_np = np.random.uniform(-10, 10, self.y_shape).astype(np.float32) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index e3f1de1048e113..5be252c3779546 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -834,6 +834,59 @@ def test_error(self): self.assertRaises(ValueError, paddle.gcd_, x, y) +class TestDygraphInplaceHypot(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.randint(2, size=200) + self.input_var_numpy = self.input_var_numpy.reshape([10, 20]) + self.dtype = "float32" + self.y = paddle.randn(shape=[10, 20], dtype="float32") + + def inplace_api_processing(self, var): + return paddle.hypot_(var, self.y) + + def non_inplace_api_processing(self, var): + return paddle.hypot(var, self.y) + + def test_errors(self): + x = 3.0 + self.assertRaises(TypeError, paddle.hypot_, x, self.y) + self.assertRaises(TypeError, paddle.hypot_, self.y, x) + + def test_forward_version(self): + with paddle.base.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 3) + + inplace_var[0] = 2.0 + self.assertEqual(var.inplace_version, 4) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 7) + + def test_backward_error(self): + # It raises an error because the inplace operator will result + # in incorrect gradient computation. + with paddle.base.dygraph.guard(): + var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + var_a.stop_gradient = False + + var_b = var_a**2 + # Here, the gradient computation will use the value of var_b + var_c = var_b**2 + self.inplace_api_processing(var_b) + var_c = paddle.cast(var_c, "float32") + + loss = paddle.nn.functional.relu(var_c) + with self.assertRaisesRegex( + RuntimeError, + f"received tensor_version:{3} != wrapper_version_snapshot:{0}", + ): + loss.backward() + + class TestDygraphInplaceNanToNum(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.array( diff --git a/test/legacy_test/test_layer_norm_op.py b/test/legacy_test/test_layer_norm_op.py index 3fb01bb3d0b62a..cc4726f3458cf4 100644 --- a/test/legacy_test/test_layer_norm_op.py +++ b/test/legacy_test/test_layer_norm_op.py @@ -143,7 +143,7 @@ def test_check_output(self): rtol=self.ori_rtol, check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=self.check_new_ir, + check_pir=self.check_pir, ) def test_check_grad(self): @@ -153,7 +153,7 @@ def test_check_grad(self): max_relative_error=self.max_relative_error, check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=self.check_new_ir, + check_pir=self.check_pir, ) def initConfig(self): @@ -177,7 +177,7 @@ def initConfig(self): self.has_bias = True self.check_prim = True self.check_prim_pir = True - self.check_new_ir = True + self.check_pir = True def initTestCase(self): np.random.seed(123) @@ -247,7 +247,7 @@ def test_check_output(self): rtol=self.ori_rtol, check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=self.check_new_ir, + check_pir=self.check_pir, ) def test_check_grad(self): @@ -258,7 +258,7 @@ def test_check_grad(self): max_relative_error=self.max_relative_error, check_prim=self.check_prim, check_prim_pir=self.check_prim_pir, - check_new_ir=self.check_new_ir, + check_pir=self.check_pir, ) def initConfig(self): @@ -275,7 +275,7 @@ def initConfig(self): self.has_bias = True self.check_prim = True self.check_prim_pir = True - self.check_new_ir = True + self.check_pir = True def initTestCase(self): np.random.seed(123) @@ -347,7 +347,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True @unittest.skipIf( @@ -371,7 +371,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True @unittest.skipIf( @@ -400,7 +400,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True @unittest.skipIf( @@ -424,7 +424,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True @unittest.skipIf( @@ -453,7 +453,7 @@ def initConfig(self): self.has_bias = True self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True @unittest.skipIf( @@ -477,7 +477,7 @@ def initConfig(self): self.has_bias = True self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest): @@ -497,7 +497,7 @@ def initConfig(self): self.has_bias = True self.check_prim = True self.check_prim_pir = True - self.check_new_ir = True + self.check_pir = True class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest): @@ -517,7 +517,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest): @@ -537,7 +537,7 @@ def initConfig(self): self.has_bias = False self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest): @@ -557,7 +557,7 @@ def initConfig(self): self.has_bias = True self.check_prim = False self.check_prim_pir = False - self.check_new_ir = True + self.check_pir = True class TestLayerNormOp(unittest.TestCase): @@ -838,6 +838,11 @@ def test_errors(self): name='x2', shape=[-1, 3, 32, 32], dtype="int32" ) self.assertRaises(TypeError, layer_norm, x2) + with paddle.pir_utils.IrGuard(), program_guard(Program(), Program()): + layer_norm = paddle.nn.LayerNorm([32, 32]) + # the input of LayerNorm must be Variable. + x1 = np.random.random((3, 32, 32)).astype('float32') + self.assertRaises(ValueError, layer_norm, x1) @unittest.skipIf( diff --git a/test/legacy_test/test_logcumsumexp_op.py b/test/legacy_test/test_logcumsumexp_op.py index 373548f679b88b..0be9a6f4d450b9 100644 --- a/test/legacy_test/test_logcumsumexp_op.py +++ b/test/legacy_test/test_logcumsumexp_op.py @@ -232,7 +232,7 @@ def setUp(self): self.outputs = {'Out': np_logcumsumexp(input, **attrs)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -245,6 +245,7 @@ def test_check_grad(self): **self.attrs ) ], + check_pir=True, ) def input_and_attrs(self): @@ -332,7 +333,7 @@ def test_check_output(self): place = core.CUDAPlace(0) place = core.CUDAPlace(0) self.check_output_with_place_customized( - checker=self.verify_output, place=place + checker=self.verify_output, place=place, check_pir=True ) def verify_output(self, outs): @@ -352,7 +353,12 @@ def verify_output(self, outs): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', numeric_grad_delta=0.5, max_relative_error=0.5 + place, + ['X'], + 'Out', + numeric_grad_delta=0.5, + max_relative_error=0.5, + check_pir=True, ) diff --git a/test/legacy_test/test_lookup_table_v2_op.py b/test/legacy_test/test_lookup_table_v2_op.py index 035aef9f7576c9..ad708eb137bb1f 100644 --- a/test/legacy_test/test_lookup_table_v2_op.py +++ b/test/legacy_test/test_lookup_table_v2_op.py @@ -62,7 +62,7 @@ def id_dtype(self): return "int64" def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): self.check_grad( @@ -70,7 +70,7 @@ def test_check_grad(self): 'Out', no_grad_set=set('Ids'), check_cinn=True, - check_new_ir=True, + check_pir=True, ) @@ -99,7 +99,7 @@ def setUp(self): self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))} def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): self.check_grad( @@ -107,7 +107,7 @@ def test_check_grad(self): 'Out', no_grad_set=set('Ids'), check_cinn=True, - check_new_ir=True, + check_pir=True, ) @@ -122,7 +122,7 @@ def test_check_output(self): padding_idx = np.random.choice(ids, 1)[0] self.outputs['Out'][ids == padding_idx] = np.zeros(31) self.attrs = {'padding_idx': int(padding_idx)} - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) @skip_check_grad_ci( @@ -137,7 +137,7 @@ def test_check_output(self): padding_idx = np.random.choice(flatten_idx, 1)[0] self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) self.attrs = {'padding_idx': padding_idx} - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) class TestLookupTableWIsSelectedRows(unittest.TestCase): @@ -355,7 +355,7 @@ def id_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_cinn=True, check_new_ir=True) + self.check_output_with_place(place, check_cinn=True, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -365,7 +365,7 @@ def test_check_grad(self): 'Out', no_grad_set=set('Ids'), check_cinn=True, - check_new_ir=True, + check_pir=True, ) diff --git a/test/legacy_test/test_lr_scheduler.py b/test/legacy_test/test_lr_scheduler.py index 54484ecc6ad2c2..ba1f712dce2fd8 100644 --- a/test/legacy_test/test_lr_scheduler.py +++ b/test/legacy_test/test_lr_scheduler.py @@ -464,6 +464,31 @@ def exp_range(x): return base_learning_rate + base_height * scale_fn(eval(scale_mode)) +linear_last_lr = None + + +def linear_lr( + epoch_num, + learning_rate, + total_steps, + start_factor=1.0 / 3, + end_factor=1.0, + verbose=False, +): + global linear_last_lr + if epoch_num == 0: + linear_last_lr = learning_rate * start_factor + return linear_last_lr + elif epoch_num > total_steps: + return linear_last_lr + else: + base_lr = total_steps * start_factor + cur_factor = end_factor - start_factor + factor = 1.0 + cur_factor / (base_lr + (epoch_num - 1) * cur_factor) + linear_last_lr *= factor + return linear_last_lr + + class TestLRScheduler(unittest.TestCase): def _test_static(self, python_func, paddle_api, kwarg, place): scheduler = paddle_api(**kwarg) @@ -711,6 +736,19 @@ def test_scheduler(self): paddle.optimizer.lr.PiecewiseDecay( boundaries=[100, 200], values=[0.5, 0.1] ) + # check minus total_steps + with self.assertRaises(ValueError): + paddle.optimizer.lr.LinearLR(learning_rate=1, total_steps=-1) + # check start_factor + with self.assertRaises(ValueError): + paddle.optimizer.lr.LinearLR( + learning_rate=1, total_steps=5, start_factor=2 + ) + # check end_factor + with self.assertRaises(ValueError): + paddle.optimizer.lr.LinearLR( + learning_rate=1, total_steps=5, end_factor=2 + ) func_api_kwargs = [ ( @@ -944,6 +982,28 @@ def test_scheduler(self): "verbose": False, }, ), + ( + linear_lr, + paddle.optimizer.lr.LinearLR, + { + "learning_rate": 0.2, + "total_steps": 40, + "start_factor": 0.5, + "end_factor": 1, + "verbose": False, + }, + ), + ( + linear_lr, + paddle.optimizer.lr.LinearLR, + { + "learning_rate": 0.2, + "total_steps": 5, + "start_factor": 0.2, + "end_factor": 0.5, + "verbose": False, + }, + ), ] for python_func, paddle_api, kwarg in func_api_kwargs: diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index e9d2ee096d7dd2..1a0254b66df52b 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -14,6 +14,7 @@ import inspect import unittest +import warnings import paddle @@ -21,6 +22,35 @@ class TestMathOpPatchesPir(unittest.TestCase): + def test_item(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + y = paddle.static.data( + name='y', + shape=[ + 3, + ], + ) + self.assertTrue(y.item() == y) + with self.assertRaises(TypeError): + x.item() + + def test_place(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + x.place() + self.assertTrue(len(w) == 1) + self.assertTrue("place" in str(w[-1].message)) + + def test_some_dim(self): + with paddle.pir_utils.IrGuard(): + x = paddle.static.data(name='x', shape=[3, 2, 1]) + self.assertEqual(x.dim(), 3) + self.assertEqual(x.ndimension(), 3) + self.assertEqual(x.ndim, 3) + def test_math_exists(self): with paddle.pir_utils.IrGuard(): a = paddle.static.data(name='a', shape=[1], dtype='float32') diff --git a/test/legacy_test/test_matmul_v2_op.py b/test/legacy_test/test_matmul_v2_op.py index 0293e0414a23ea..eb893971e026b2 100644 --- a/test/legacy_test/test_matmul_v2_op.py +++ b/test/legacy_test/test_matmul_v2_op.py @@ -99,7 +99,7 @@ def setUp(self): def test_check_output(self): self.check_output( check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): @@ -111,7 +111,7 @@ def test_check_grad(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) else: self.check_grad( @@ -120,7 +120,7 @@ def test_check_grad(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) @@ -362,7 +362,7 @@ def test_check_output(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): @@ -376,7 +376,7 @@ def test_check_grad(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) cls_name = "{}_{}".format(parent.__name__, "Fp16") @@ -436,7 +436,7 @@ def test_check_output(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_x(self): @@ -453,7 +453,7 @@ def test_check_grad_x(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) def test_check_grad_y(self): @@ -470,7 +470,7 @@ def test_check_grad_y(self): check_cinn=self.check_cinn if hasattr(self, 'check_cinn') else True, - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): @@ -745,7 +745,7 @@ def init_input_output(self): self.out = np.matmul(self.x, self.y) def test_check_output(self): - self.check_output(check_cinn=False, check_new_ir=True) + self.check_output(check_cinn=False, check_pir=True) class TestInt32MatMulOpBroadcast(OpTest): @@ -797,7 +797,7 @@ def init_input_output(self): self.out = np.matmul(self.x, self.y) def test_check_output(self): - self.check_output(check_cinn=False, check_new_ir=True) + self.check_output(check_cinn=False, check_pir=True) class TestInt64MatMulOpBroadcast(OpTest): diff --git a/test/legacy_test/test_mean_op.py b/test/legacy_test/test_mean_op.py index ee8cf92ffcb0a5..e217b31d980d65 100644 --- a/test/legacy_test/test_mean_op.py +++ b/test/legacy_test/test_mean_op.py @@ -23,6 +23,7 @@ import paddle from paddle import base from paddle.base import Program, core, program_guard +from paddle.pir_utils import test_with_pir_api np.random.seed(10) @@ -52,10 +53,10 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_checkout_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) class TestMeanOp_ZeroDim(OpTest): @@ -67,18 +68,26 @@ def setUp(self): self.outputs = {'Out': np.mean(self.inputs["X"])} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_checkout_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) class TestMeanOpError(unittest.TestCase): + def setUp(self): + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.int32) + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + def test_errors(self): paddle.enable_static() with program_guard(Program(), Program()): # The input type of mean_op must be Variable. - input1 = 12 self.assertRaises(TypeError, paddle.mean, input1) # The input dtype of mean_op must be float16, float32, float64. @@ -90,6 +99,20 @@ def test_errors(self): name='input3', shape=[-1, 4], dtype="float16" ) paddle.nn.functional.softmax(input3) + + with paddle.pir_utils.IrGuard(), program_guard(Program(), Program()): + input1 = 12 + self.assertRaises(ValueError, paddle.mean, input1) + + input2 = paddle.static.data( + name='input2', shape=[2, 3, 4, 5], dtype="int32" + ) + + out = paddle.mean(input2) + + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'input2': self.x}, fetch_list=[out]) + paddle.disable_static() @@ -104,7 +127,7 @@ def init_dtype_type(self): def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_checkout_grad(self): place = core.CUDAPlace(0) @@ -128,11 +151,11 @@ def init_dtype_type(self): def test_check_output(self): paddle.enable_static() - self.check_output_with_place(core.CPUPlace(), check_new_ir=True) + self.check_output_with_place(core.CPUPlace(), check_pir=True) def test_checkout_grad(self): place = core.CPUPlace() - self.check_grad_with_place(place, ['X'], 'Out', check_new_ir=True) + self.check_grad_with_place(place, ['X'], 'Out', check_pir=True) def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False): @@ -190,7 +213,7 @@ def if_enable_cinn(self): def test_check_output(self): if self.dtype != 'float16': self.check_output( - check_prim=True, check_prim_pir=True, check_new_ir=True + check_prim=True, check_prim_pir=True, check_pir=True ) else: place = paddle.CUDAPlace(0) @@ -198,7 +221,7 @@ def test_check_output(self): place=place, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): @@ -208,7 +231,7 @@ def test_check_grad(self): ['Out'], check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) else: place = paddle.CUDAPlace(0) @@ -219,7 +242,7 @@ def test_check_grad(self): numeric_grad_delta=0.5, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -446,6 +469,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_api_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_min_op.py b/test/legacy_test/test_min_op.py index e24471b20dca8f..78601c77ecf069 100644 --- a/test/legacy_test/test_min_op.py +++ b/test/legacy_test/test_min_op.py @@ -21,6 +21,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api class ApiMinTest(unittest.TestCase): @@ -30,6 +31,7 @@ def setUp(self): else: self.place = core.CPUPlace() + @test_with_pir_api def test_api(self): paddle.enable_static() with paddle.static.program_guard( diff --git a/test/legacy_test/test_minimum_op.py b/test/legacy_test/test_minimum_op.py index 6267b78b4cf9db..79970ce77f406b 100644 --- a/test/legacy_test/test_minimum_op.py +++ b/test/legacy_test/test_minimum_op.py @@ -18,6 +18,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class ApiMinimumTest(unittest.TestCase): @@ -39,6 +40,7 @@ def setUp(self): self.np_expected3 = np.minimum(self.input_a, self.input_c) self.np_expected4 = np.minimum(self.input_b, self.input_c) + @test_with_pir_api def test_static_api(self): paddle.enable_static() with paddle.static.program_guard( @@ -119,3 +121,7 @@ def test_dynamic_api(self): res = paddle.minimum(b, c) res = res.numpy() np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/legacy_test/test_multinomial_op.py b/test/legacy_test/test_multinomial_op.py index bb4c53fb348217..e886876b27583a 100644 --- a/test/legacy_test/test_multinomial_op.py +++ b/test/legacy_test/test_multinomial_op.py @@ -59,7 +59,7 @@ def init_data(self): self.attrs = {"num_samples": 100000, "replacement": True} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def sample_output(self, out): return sample_output_one_dimension(out, 4) @@ -122,7 +122,7 @@ def init_data(self): self.attrs = {"num_samples": 100000, "replacement": True} def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def sample_output(self, out): return sample_output_one_dimension(out, 4) @@ -178,6 +178,7 @@ class TestMultinomialBF16OP(OpTest): def setUp(self): paddle.enable_static() self.op_type = "multinomial" + self.python_api = paddle.multinomial self.dtype = np.uint16 self.init_data() self.inputs = {"X": convert_float_to_uint16(self.input_np)} @@ -190,7 +191,9 @@ def init_data(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place_customized(self.verify_output, place) + self.check_output_with_place_customized( + self.verify_output, place, check_pir=True + ) def sample_output(self, out): return sample_output_one_dimension(out, 4) diff --git a/test/legacy_test/test_numel_op.py b/test/legacy_test/test_numel_op.py index 33f1dc7cf4c2cc..32f043dab1b9b5 100644 --- a/test/legacy_test/test_numel_op.py +++ b/test/legacy_test/test_numel_op.py @@ -34,7 +34,7 @@ def setUp(self): self.outputs = {'Out': np.array(np.size(x))} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def init(self): self.shape = (6, 56, 8, 55) @@ -136,7 +136,7 @@ def setUp(self): def test_check_output(self): place = paddle.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def init(self): self.shape = (6, 56, 8, 55) diff --git a/test/legacy_test/test_pad3d_op.py b/test/legacy_test/test_pad3d_op.py index 42efb91a166d17..52c9557766914c 100644 --- a/test/legacy_test/test_pad3d_op.py +++ b/test/legacy_test/test_pad3d_op.py @@ -91,10 +91,10 @@ def setUp(self): self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out']) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) def get_dtype(self): return np.float64 @@ -214,11 +214,11 @@ def get_dtype(self): return np.float16 def test_check_output(self): - self.check_output(atol=1e-3, check_new_ir=True) + self.check_output(atol=1e-3, check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['X'], 'Out', max_relative_error=1.5e-3, check_new_ir=True + ['X'], 'Out', max_relative_error=1.5e-3, check_pir=True ) cls_name = "{}_{}".format(parent.__name__, "FP16OP") @@ -253,12 +253,12 @@ def get_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-2, check_new_ir=True) + self.check_output_with_place(place, atol=1e-2, check_pir=True) def test_check_grad_normal(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=1e-2, check_new_ir=True + place, ['X'], 'Out', max_relative_error=1e-2, check_pir=True ) cls_name = "{}_{}".format(parent.__name__, "BF16OP") diff --git a/test/legacy_test/test_pad_op.py b/test/legacy_test/test_pad_op.py index 8054d7c75ffb11..81efa838178e8f 100644 --- a/test/legacy_test/test_pad_op.py +++ b/test/legacy_test/test_pad_op.py @@ -21,7 +21,8 @@ from utils import static_guard import paddle -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api def pad_wrapper(x, paddings, pad_value): @@ -57,10 +58,16 @@ def get_dtype(self): return np.float64 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def initTestCase(self): self.shape = (16, 16) @@ -101,7 +108,13 @@ def get_dtype(self): return np.float16 def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) cls_name = "{}_{}".format(parent.__name__, "Fp16") TestPadFp16.__name__ = cls_name @@ -117,7 +130,9 @@ def test_check_grad_normal(self): class TestPadOpError(unittest.TestCase): def test_errors(self): with static_guard(): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input_data = np.random.random((2, 2)).astype("float32") def test_Variable(): @@ -138,9 +153,9 @@ def init_info(self): def test_static(self): with static_guard(): - main_prog = Program() - starup_prog = Program() - with program_guard(main_prog, starup_prog): + main_prog = paddle.static.Program() + starup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, starup_prog): fc = paddle.nn.Linear(4, 10) x = paddle.randn([2, 4]) x.stop_gradient = False @@ -159,6 +174,7 @@ def test_static(self): res[0], [1, 1], 'constant', constant_values=[1.0, 1.0] ) np.testing.assert_allclose(res[1], gt) + paddle.static.save_inference_model( self.save_path, [x], [feat, out], exe ) @@ -172,6 +188,29 @@ def test_static(self): ) np.testing.assert_allclose(infer_outs[1], gt) + def test_pir_static(self): + with paddle.pir_utils.IrGuard(): + main_prog = paddle.static.Program() + starup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + + out = self.call_func(feat) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[feat, out]) + gt = np.pad( + res[0], [1, 1], 'constant', constant_values=[1.0, 1.0] + ) + np.testing.assert_allclose(res[1], gt) + def path_prefix(self): return 'padding_value' @@ -196,12 +235,13 @@ def call_func(self, x): class TestPaddingValueTensor3(unittest.TestCase): + @test_with_pir_api def test_static(self): with static_guard(): np_x = np.random.random((16, 16)).astype('float32') - main_prog = Program() - starup_prog = Program() - with program_guard(main_prog, starup_prog): + main_prog = paddle.static.Program() + starup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, starup_prog): x = paddle.assign(np_x).astype('float32') pad_value = paddle.assign([0.0]).astype('float64') y = paddle.nn.functional.pad(x, [0, 1, 2, 3], value=pad_value) @@ -253,12 +293,17 @@ def initTestCase(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_parallel_dygraph_dataparallel.py b/test/legacy_test/test_parallel_dygraph_dataparallel.py index de3160e9c6f9c9..b3cbfbf0966f89 100644 --- a/test/legacy_test/test_parallel_dygraph_dataparallel.py +++ b/test/legacy_test/test_parallel_dygraph_dataparallel.py @@ -121,6 +121,7 @@ def start_local_trainers( "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint, "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "FLAGS_dynamic_static_unified_comm": "0", } proc_env["FLAGS_allocator_strategy"] = allocator_strategy diff --git a/test/legacy_test/test_poisson_op.py b/test/legacy_test/test_poisson_op.py index 2002b94ac8013a..b2b889645ddfc8 100644 --- a/test/legacy_test/test_poisson_op.py +++ b/test/legacy_test/test_poisson_op.py @@ -63,7 +63,7 @@ def verify_output(self, outs): np.testing.assert_allclose(hist, prob, rtol=0.01) def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -73,6 +73,7 @@ def test_check_grad_normal(self): user_defined_grad_outputs=[ np.random.rand(2048, 1024).astype(self.dtype) ], + check_pir=True, ) @@ -409,7 +410,7 @@ def verify_output(self, outs): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place_customized( - self.verify_output, place, check_new_ir=True + self.verify_output, place, check_pir=True ) def test_check_grad(self): @@ -422,7 +423,7 @@ def test_check_grad(self): user_defined_grad_outputs=[ np.random.rand(2048, 1024).astype("float32") ], - check_new_ir=True, + check_pir=True, ) diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py index fcca5381fa4f06..84615340fe051e 100644 --- a/test/legacy_test/test_pool2d_api.py +++ b/test/legacy_test/test_pool2d_api.py @@ -25,6 +25,7 @@ from paddle import base from paddle.base import core from paddle.nn.functional import avg_pool2d, max_pool2d +from paddle.pir_utils import test_with_pir_api class TestPool2D_API(unittest.TestCase): @@ -52,7 +53,7 @@ def check_avg_static_results(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -144,7 +145,7 @@ def check_max_static_results(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"input": input_np}, fetch_list=[result], ) @@ -360,8 +361,6 @@ def test_pool2d(self): for place in self.places: self.check_max_dygraph_results(place) self.check_avg_dygraph_results(place) - self.check_max_static_results(place) - self.check_avg_static_results(place) self.check_max_dygraph_stride_is_none(place) self.check_avg_dygraph_stride_is_none(place) self.check_max_dygraph_padding(place) @@ -370,6 +369,14 @@ def test_pool2d(self): self.check_max_dygraph_ceilmode_results(place) self.check_max_dygraph_nhwc_results(place) + @test_with_pir_api + def test_pool2d_static(self): + paddle.enable_static() + for place in self.places: + self.check_max_static_results(place) + self.check_avg_static_results(place) + paddle.disable_static() + class TestPool2DError_API(unittest.TestCase): def test_error_api(self): diff --git a/test/legacy_test/test_print_op.py b/test/legacy_test/test_print_op.py index 3352d2b23ef937..c4390d76bb9ffd 100755 --- a/test/legacy_test/test_print_op.py +++ b/test/legacy_test/test_print_op.py @@ -97,8 +97,8 @@ def test_errors(self): np.array([[-1]]), [[1]], paddle.CPUPlace() ) self.assertRaises(TypeError, paddle.static.Print, x1) - # The input dtype of Print_op must be float32, float64, int32_t, int64_t or bool. - x2 = paddle.static.data(name='x2', shape=[4], dtype="float16") + # The input dtype of Print_op must be uint16, float16, float32, float64, int32_t, int64_t or bool. + x2 = paddle.static.data(name='x2', shape=[4], dtype="int8") self.assertRaises(TypeError, paddle.static.Print, x2) diff --git a/test/legacy_test/test_randint_op.py b/test/legacy_test/test_randint_op.py index a48750eebdc7d3..fefae2c4d81648 100644 --- a/test/legacy_test/test_randint_op.py +++ b/test/legacy_test/test_randint_op.py @@ -46,7 +46,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -87,7 +87,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -107,7 +107,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) diff --git a/test/legacy_test/test_randperm_op.py b/test/legacy_test/test_randperm_op.py index ceb8b82aa0f55d..9cb270801fece9 100644 --- a/test/legacy_test/test_randperm_op.py +++ b/test/legacy_test/test_randperm_op.py @@ -83,7 +83,7 @@ def init_attrs(self): pass def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): out_np = np.array(outs[0]) @@ -144,7 +144,9 @@ def init_attrs(self): self.np_dtype = np.float32 def test_check_output(self): - self.check_output_with_place_customized(self.verify_output, self.place) + self.check_output_with_place_customized( + self.verify_output, self.place, check_pir=True + ) def verify_output(self, outs): out_np = convert_uint16_to_float(np.array(outs[0])) diff --git a/test/legacy_test/test_reduce_op.py b/test/legacy_test/test_reduce_op.py index d60fb8bfeb1468..a7595cd1331c83 100644 --- a/test/legacy_test/test_reduce_op.py +++ b/test/legacy_test/test_reduce_op.py @@ -22,6 +22,7 @@ from paddle import base from paddle.base import Program, core, program_guard from paddle.base.framework import convert_np_dtype_to_dtype_, in_pir_mode +from paddle.pir_utils import test_with_pir_api class TestSumOp(OpTest): @@ -55,14 +56,14 @@ def calc_output(self): self.out = self.x.sum(axis=tuple(self.attrs['dim'])) def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -95,7 +96,7 @@ def test_check_grad(self): self.check_grad( ['X'], 'Out', - check_new_ir=True, + check_pir=True, check_prim=True, check_prim_pir=True, ) @@ -125,10 +126,10 @@ def init_attrs(self): self.attrs = {'dim': (0, 3)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) class TestSumOp_withInt(TestSumOp): @@ -141,7 +142,7 @@ def init_attrs(self): self.attrs = {'dim': (0, 1)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def calc_gradient(self): x = self.inputs["X"] @@ -155,7 +156,7 @@ def test_check_grad(self): user_defined_grads=self.calc_gradient(), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -167,7 +168,7 @@ def init_attrs(self): self.attrs = {'dim': (0, 1, 2)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def calc_gradient(self): x = self.inputs["X"] @@ -181,7 +182,7 @@ def test_check_grad(self): user_defined_grads=self.calc_gradient(), check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -194,7 +195,7 @@ def init_dtype(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -202,7 +203,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -231,7 +232,7 @@ def init_dtype(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_new_ir=True) + self.check_output_with_place(place, check_pir=True) def test_check_grad(self): place = core.CUDAPlace(0) @@ -242,7 +243,7 @@ def test_check_grad(self): user_defined_grads=self.gradient, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def calc_gradient(self): @@ -279,7 +280,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): # only composite op support gradient check of reduce_max @@ -288,7 +289,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -314,7 +315,7 @@ def init_inputs_and_outputs(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): # only composite op support gradient check of reduce_max @@ -323,7 +324,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -368,7 +369,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): # only composite op support gradient check of reduce_max @@ -377,7 +378,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) def init_dtype(self): @@ -403,7 +404,7 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) def test_check_grad(self): # only composite op support gradient check of reduce_max @@ -413,7 +414,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -826,7 +827,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllFloatOp(OpTest): @@ -838,7 +839,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllIntOp(OpTest): @@ -850,7 +851,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllOp_ZeroDim(OpTest): @@ -862,7 +863,7 @@ def setUp(self): self.attrs = {'dim': []} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAll8DOp(OpTest): @@ -878,7 +879,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllOpWithDim(OpTest): @@ -890,7 +891,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAll8DOpWithDim(OpTest): @@ -906,7 +907,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].all(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllOpWithKeepDim(OpTest): @@ -920,7 +921,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAll8DOpWithKeepDim(OpTest): @@ -940,7 +941,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) class TestAllOpError(unittest.TestCase): @@ -964,7 +965,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyFloatOp(OpTest): @@ -976,7 +977,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyIntOp(OpTest): @@ -988,7 +989,7 @@ def setUp(self): self.attrs = {'reduce_all': True} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOp_ZeroDim(OpTest): @@ -1000,7 +1001,7 @@ def setUp(self): self.attrs = {'dim': []} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOp(OpTest): @@ -1016,7 +1017,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpWithDim(OpTest): @@ -1028,7 +1029,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=1)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOpWithDim(OpTest): @@ -1044,7 +1045,7 @@ def setUp(self): self.outputs = {'Out': self.inputs['X'].any(axis=self.attrs['dim'])} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpWithKeepDim(OpTest): @@ -1060,7 +1061,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAny8DOpWithKeepDim(OpTest): @@ -1080,7 +1081,7 @@ def setUp(self): } def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestAnyOpError(unittest.TestCase): @@ -1303,7 +1304,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): # only composite op support gradient check of reduce_max @@ -1312,7 +1313,7 @@ def test_check_grad(self): 'Out', check_prim=True, only_check_prim=True, - check_new_ir=True, + check_pir=True, ) @@ -1615,6 +1616,15 @@ def test_errors(self): x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="uint8") self.assertRaises(TypeError, paddle.sum, x2) + with paddle.pir_utils.IrGuard(), program_guard( + Program(), Program() + ): + # The input type of reduce_sum_op must be Variable. + x1 = base.create_lod_tensor( + np.array([[-1]]), [[1]], base.CPUPlace() + ) + self.assertRaises(ValueError, paddle.sum, x1) + class API_TestSumOp(unittest.TestCase): def run_static( @@ -1645,6 +1655,7 @@ def run_static( rtol=1e-05, ) + @test_with_pir_api def test_static(self): shape = [10, 10] axis = 1 @@ -1803,21 +1814,25 @@ def setUp(self): self.places.append(base.CUDAPlace(0)) def check_static_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="bool") result = paddle.any(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("bool") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) def check_static_float_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data( name="input", shape=[4, 4], dtype="float" ) @@ -1826,26 +1841,29 @@ def check_static_float_result(self, place): exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) def check_static_int_result(self, place): - with base.program_guard(base.Program(), base.Program()): + main = paddle.static.Program() + startup = paddle.static.Program() + with base.program_guard(main, startup): input = paddle.static.data(name="input", shape=[4, 4], dtype="int") result = paddle.any(x=input) input_np = np.random.randint(0, 2, [4, 4]).astype("int") exe = base.Executor(place) fetches = exe.run( - base.default_main_program(), + main, feed={"input": input_np}, fetch_list=[result], ) self.assertTrue((fetches[0] == np.any(input_np)).all()) + @test_with_pir_api def test_static(self): for place in self.places: self.check_static_result(place=place) diff --git a/test/legacy_test/test_reshape_op.py b/test/legacy_test/test_reshape_op.py index f0128173b44894..dd1f7e00447343 100755 --- a/test/legacy_test/test_reshape_op.py +++ b/test/legacy_test/test_reshape_op.py @@ -44,14 +44,14 @@ def init_data(self): self.infered_shape = (12, 10) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( ["X"], "Out", check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -123,7 +123,7 @@ def init_data(self): self.infered_shape = (12, 10) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -131,7 +131,7 @@ def test_check_grad(self): "Out", check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -157,7 +157,7 @@ def init_data(self): self.infered_shape = (12, 10) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -165,7 +165,7 @@ def test_check_grad(self): "Out", check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -209,7 +209,7 @@ def init_data(self): self.actual_shape = (2, 3, 20) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -217,7 +217,7 @@ def test_check_grad(self): "Out", check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -254,7 +254,7 @@ def init_data(self): self.shape = (-1, -1) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -262,7 +262,7 @@ def test_check_grad(self): "Out", check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -308,7 +308,7 @@ def init_data(self): self.infered_shape = (10, 10) def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -316,7 +316,7 @@ def test_check_grad(self): "Out", check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -370,7 +370,7 @@ def test_check_output(self): base.core.CPUPlace(), atol=1e-5, no_check_set=['XShape'], - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): diff --git a/test/legacy_test/test_run.py b/test/legacy_test/test_run.py index e0ec7c9657fb54..331d45a514a932 100644 --- a/test/legacy_test/test_run.py +++ b/test/legacy_test/test_run.py @@ -207,4 +207,5 @@ def test_ps_4(self): if __name__ == '__main__': + os.environ["FLAGS_dynamic_static_unified_comm"] = "0" unittest.main() diff --git a/test/legacy_test/test_scale_op.py b/test/legacy_test/test_scale_op.py index a6cea49a2bce32..5f33de74b3b614 100644 --- a/test/legacy_test/test_scale_op.py +++ b/test/legacy_test/test_scale_op.py @@ -42,10 +42,10 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) class TestScaleOpScaleVariable(OpTest): @@ -66,10 +66,10 @@ def init_dtype_type(self): pass def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) class TestScaleOpSelectedRows(unittest.TestCase): @@ -150,10 +150,10 @@ def init_dtype_type(self): self.dtype = np.float16 def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) @unittest.skipIf( @@ -172,10 +172,10 @@ def setUp(self): self.outputs = {'Out': convert_float_to_uint16(out)} def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=0.8, check_new_ir=True) + self.check_grad(['X'], 'Out', numeric_grad_delta=0.8, check_pir=True) @unittest.skipIf( diff --git a/test/legacy_test/test_shape_op.py b/test/legacy_test/test_shape_op.py index 6ced0cfd4a8c89..4ee95e9c4f3bde 100644 --- a/test/legacy_test/test_shape_op.py +++ b/test/legacy_test/test_shape_op.py @@ -36,7 +36,7 @@ def config(self): self.dtype = np.float32 def test_check_output(self): - self.check_output(check_cinn=True, check_new_ir=True) + self.check_output(check_cinn=True, check_pir=True) class case1(TestShapeOp): @@ -125,7 +125,7 @@ def config(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_cinn=True, check_new_ir=True) + self.check_output_with_place(place, check_cinn=True, check_pir=True) class case1Bf16(TestShapeOpBf16): diff --git a/test/legacy_test/test_slice_op.py b/test/legacy_test/test_slice_op.py index 065251b246928e..8791bf94c16dc3 100644 --- a/test/legacy_test/test_slice_op.py +++ b/test/legacy_test/test_slice_op.py @@ -67,7 +67,7 @@ def config(self): self.out = self.input[1:3, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -75,7 +75,8 @@ def test_check_grad_normal(self): 'Out', max_relative_error=0.006, check_prim=True, - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) @@ -125,7 +126,7 @@ def config(self): self.out = self.input[1:2] def test_check_output(self): - self.check_output_with_place(paddle.CPUPlace(), check_new_ir=True) + self.check_output_with_place(paddle.CPUPlace(), check_pir=True) # 1.2 with attr(decrease) @@ -157,7 +158,7 @@ def config(self): self.out = self.input[1:2, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( @@ -165,11 +166,11 @@ def test_check_grad_normal(self): 'Out', max_relative_error=0.006, check_prim=True, - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) -# Situation 2: starts(list, have tensor), ends(list, no tensor) # without attr(decrease) class TestSliceOp_starts_ListTensor(OpTest): def setUp(self): @@ -203,11 +204,11 @@ def config(self): self.starts_infer = [-1, 0, -1] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -248,11 +249,11 @@ def config(self): self.starts_infer = [1, -1, 2] def test_check_output(self): - self.check_output(check_dygraph=True, check_new_ir=True) + self.check_output(check_dygraph=True, check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -301,11 +302,11 @@ def config(self): self.out = self.input[1, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -339,11 +340,11 @@ def config(self): self.out = self.input[1:3, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -378,11 +379,11 @@ def config(self): self.out = self.input[1, 0, 2:4, :] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -424,11 +425,11 @@ def config(self): self.ends_infer = [-1, 3, 4] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_new_ir=True + ['Input'], 'Out', max_relative_error=0.006, check_pir=True ) @@ -468,10 +469,10 @@ def config(self): self.out = self.input[0:20, 1:3, 1:3] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['Input'], 'Out', check_new_ir=True) + self.check_grad(['Input'], 'Out', check_pir=True) # Test CUDA float16 @@ -507,7 +508,7 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, check_prim=True, check_new_ir=True + place, check_prim=True, check_pir=True, check_prim_pir=True ) def test_check_grad_normal(self): @@ -519,7 +520,8 @@ def test_check_grad_normal(self): ['Input'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) @@ -555,7 +557,7 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, check_prim=True, check_new_ir=True + place, check_prim=True, check_pir=True, check_prim_pir=True ) def test_check_grad_normal(self): @@ -567,7 +569,8 @@ def test_check_grad_normal(self): 'Out', numeric_grad_delta=0.5, check_prim=True, - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) @@ -597,10 +600,16 @@ def config(self): self.infer_flags = [1, 1, 1] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['Input'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['Input'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) # Test python API diff --git a/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py b/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py index 8b6e944e89a19b..0f5e2aee011737 100644 --- a/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py +++ b/test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base, incubate from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -50,11 +51,11 @@ def setUp(self): self.outputs = {'Out': rst} def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - core.CUDAPlace(0), ["X"], "Out", check_new_ir=True + core.CUDAPlace(0), ["X"], "Out", check_pir=True ) @@ -72,14 +73,14 @@ def setUp(self): def test_check_output(self): try: - self.check_output_with_place(core.CPUPlace(), check_new_ir=True) + self.check_output_with_place(core.CPUPlace(), check_pir=True) except (NotImplementedError, RuntimeError): pass def test_check_grad(self): try: self.check_grad_with_place( - core.CPUPlace(), ["X"], "Out", check_new_ir=True + core.CPUPlace(), ["X"], "Out", check_pir=True ) except (NotImplementedError, RuntimeError): pass @@ -92,11 +93,14 @@ class TestDropoutBiasFuseOp2(unittest.TestCase): # test the python side API for softmax_mask_fuse op def setUp(self): np.random.seed(123) - self.dtypes = ['float16', 'float32'] + self.dtypes = ['float32', 'float16'] + @test_with_pir_api def test_static(self): for dtype in self.dtypes: - with base.program_guard(base.Program(), base.Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): input_x = paddle.static.data( name="x", shape=[1, 4, 32, 32], dtype=dtype ) @@ -107,7 +111,7 @@ def test_static(self): exe = base.Executor(base.CUDAPlace(0)) fetches = exe.run( - base.default_main_program(), + paddle.static.default_main_program(), feed={"x": x_in_np}, fetch_list=[rst], ) diff --git a/test/legacy_test/test_softmax_op.py b/test/legacy_test/test_softmax_op.py index e684daa695a23e..ae98b434766192 100644 --- a/test/legacy_test/test_softmax_op.py +++ b/test/legacy_test/test_softmax_op.py @@ -84,9 +84,17 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + self.check_output_with_place( + place, + atol=1e-5, + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_pir=True, check_prim_pir=True + ) def test_check_grad(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode @@ -99,7 +107,8 @@ def test_check_grad(self): "Out", max_relative_error=0.01, check_dygraph=(not self.use_mkldnn), - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) else: self.check_grad( @@ -108,7 +117,8 @@ def test_check_grad(self): max_relative_error=0.01, check_dygraph=(not self.use_mkldnn), check_prim=True, - check_new_ir=True, + check_pir=True, + check_prim_pir=True, ) @@ -146,9 +156,13 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + self.check_output_with_place( + place, atol=1e-5, check_pir=True, check_prim_pir=True + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_pir=True, check_prim_pir=True + ) @unittest.skipIf( @@ -158,6 +172,8 @@ class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp): def setUp(self): self.op_type = "softmax" self.python_api = F.softmax + self.public_python_api = F.softmax + self.prim_op_type = "comp" self.use_cudnn = True self.use_mkldnn = False # explicilty use float32 for ROCm, as MIOpen does not yet support float64 @@ -180,9 +196,17 @@ def test_check_output(self): # TODO(wangzhongpu): support mkldnn op in dygraph mode if self.use_cudnn: place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_new_ir=True) + self.check_output_with_place( + place, + check_prim=True, + atol=1e-5, + check_pir=True, + check_prim_pir=True, + ) else: - self.check_output(check_prim=True, check_new_ir=True) + self.check_output( + check_prim=True, check_pir=True, check_prim_pir=True + ) class TestSoftmaxOp2(TestSoftmaxOp): @@ -357,7 +381,11 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=1e-3, check_new_ir=True + place, + atol=1e-3, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) # FIXME: If the x_shape is [10, 10], gradient failed. @@ -386,7 +414,11 @@ def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): self.check_output_with_place( - place, atol=1e-3, check_new_ir=True + place, + atol=1e-3, + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -437,7 +469,8 @@ def test_check_output(self): place, check_dygraph=(not self.use_mkldnn), check_prim=True, - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), + check_prim_pir=(not self.use_mkldnn), ) def test_check_grad(self): @@ -449,7 +482,8 @@ def test_check_grad(self): numeric_grad_delta=0.05, check_dygraph=(not self.use_mkldnn), check_prim=True, - check_new_ir=(not self.use_mkldnn), + check_pir=(not self.use_mkldnn), + check_prim_pir=(not self.use_mkldnn), ) diff --git a/test/legacy_test/test_split_op.py b/test/legacy_test/test_split_op.py index 92dfe72f8443e3..a192078899dd7c 100644 --- a/test/legacy_test/test_split_op.py +++ b/test/legacy_test/test_split_op.py @@ -57,7 +57,7 @@ def _set_op_type(self): self.op_type = "split" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -65,7 +65,7 @@ def test_check_grad(self): ['out0', 'out1', 'out2'], check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -117,7 +117,7 @@ def _set_op_type(self): self.op_type = "split" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -125,7 +125,7 @@ def test_check_grad(self): ['out0', 'out1', 'out2'], check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -160,10 +160,10 @@ def _set_op_type(self): self.op_type = "split" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['out0', 'out1', 'out2'], check_new_ir=True) + self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True) # attr(sections) is list containing Tensor @@ -208,10 +208,10 @@ def _set_op_type(self): self.op_type = "split" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], ['out0', 'out1', 'out2'], check_new_ir=True) + self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True) class TestSplitOp_unk_section(OpTest): @@ -247,7 +247,7 @@ def _set_op_type(self): self.op_type = "split" def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( @@ -255,7 +255,7 @@ def test_check_grad(self): ['out0', 'out1', 'out2'], check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -308,7 +308,7 @@ def test_check_grad(self): 'out2', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) cls_name = "{}_{}".format(parent.__name__, "BF16Op") diff --git a/test/legacy_test/test_squeeze2_op.py b/test/legacy_test/test_squeeze2_op.py index 1ee72ad2a39e0b..f1a689024de8a2 100755 --- a/test/legacy_test/test_squeeze2_op.py +++ b/test/legacy_test/test_squeeze2_op.py @@ -56,11 +56,20 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( - no_check_set=['XShape'], check_prim=True, check_new_ir=True + no_check_set=['XShape'], + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad( + ["X"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def init_dtype(self): self.dtype = np.float64 diff --git a/test/legacy_test/test_stack_op.py b/test/legacy_test/test_stack_op.py index 0b0d73e9fafc18..fb8eda704db6ac 100644 --- a/test/legacy_test/test_stack_op.py +++ b/test/legacy_test/test_stack_op.py @@ -63,11 +63,15 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=False) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): self.check_grad( - self.get_x_names(), 'Y', check_prim=True, check_new_ir=False + self.get_x_names(), + 'Y', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -189,11 +193,15 @@ def setUp(self): self.attrs = {'axis': self.axis} def test_check_output(self): - self.check_output(check_prim=True, check_new_ir=False) + self.check_output(check_prim=True, check_pir=True, check_prim_pir=True) def test_check_grad(self): self.check_grad( - self.get_x_names(), 'Y', check_prim=True, check_new_ir=False + self.get_x_names(), + 'Y', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) diff --git a/test/legacy_test/test_sum_op.py b/test/legacy_test/test_sum_op.py index 910d8a75e5f9f1..d8536bc7719553 100644 --- a/test/legacy_test/test_sum_op.py +++ b/test/legacy_test/test_sum_op.py @@ -61,7 +61,7 @@ def test_check_output(self): self.check_output( check_prim=True, check_cinn=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -71,7 +71,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_cinn=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -310,7 +310,7 @@ def test_check_output(self): check_cinn=True, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) # FIXME: Because of the precision fp16, max_relative_error @@ -324,7 +324,7 @@ def test_check_grad(self): check_cinn=True, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -377,7 +377,7 @@ def test_check_output(self): check_dygraph=False, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def test_check_grad(self): @@ -388,7 +388,7 @@ def test_check_grad(self): check_dygraph=False, check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) diff --git a/test/legacy_test/test_tile_op.py b/test/legacy_test/test_tile_op.py index 40dc04b0537707..4a7d94637c6fa5 100644 --- a/test/legacy_test/test_tile_op.py +++ b/test/legacy_test/test_tile_op.py @@ -21,7 +21,8 @@ import paddle from paddle import base -from paddle.base import Program, core, program_guard +from paddle.base import core +from paddle.pir_utils import test_with_pir_api # Situation 1: repeat_times is a list (without tensor) @@ -47,10 +48,16 @@ def init_data(self): self.repeat_times = [2] def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) class TestTileOpRank_ZeroDim1(TestTileOpRank1): @@ -165,7 +172,7 @@ def init_data(self): self.infer_repeat_times = [-1] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -206,7 +213,7 @@ def init_data(self): self.repeat_times = [2] def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -235,7 +242,7 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) class TestTileFP16OP(OpTest): @@ -262,10 +269,16 @@ def init_data(self): self.repeat_times = [2, 1, 4] def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) @unittest.skipIf( @@ -294,7 +307,7 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_cinn=self.check_cinn, check_new_ir=True + place, check_cinn=self.check_cinn, check_pir=True ) def init_data(self): @@ -305,7 +318,12 @@ def init_data(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=True + place, + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, ) @@ -324,7 +342,7 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) # Situation 56: input x is Integer @@ -344,12 +362,15 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) class TestTileError(unittest.TestCase): + @test_with_pir_api def test_errors(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): x1 = base.create_lod_tensor( np.array([[-1]]), [[1]], base.CPUPlace() ) @@ -363,8 +384,11 @@ def test_errors(self): class TestTileAPIStatic(unittest.TestCase): + @test_with_pir_api def test_api(self): - with program_guard(Program(), Program()): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): repeat_times = [2, 2] x1 = paddle.static.data(name='x1', shape=[-1, 4], dtype="int32") out = paddle.tile(x1, repeat_times) @@ -490,6 +514,7 @@ def test_dygraph(self): class Testfp16TileOp(unittest.TestCase): + @test_with_pir_api def testfp16(self): input_x = (np.random.random([1, 2, 3])).astype('float16') with paddle.static.program_guard(paddle.static.Program()): diff --git a/test/legacy_test/test_transpose_op.py b/test/legacy_test/test_transpose_op.py index 52f85ef1e0a708..32f071eafb472b 100644 --- a/test/legacy_test/test_transpose_op.py +++ b/test/legacy_test/test_transpose_op.py @@ -49,14 +49,14 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( ['X'], 'Out', check_prim=True, - check_new_ir=True, + check_pir=True, check_prim_pir=True, ) @@ -211,7 +211,7 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) base.core.disable_autotune() def test_check_grad(self): @@ -220,7 +220,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -254,7 +254,7 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) base.core.disable_autotune() def test_check_grad(self): @@ -263,7 +263,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -304,7 +304,7 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) base.core.disable_autotune() def test_check_grad(self): @@ -313,7 +313,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) @@ -346,7 +346,7 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): self.check_grad( @@ -354,7 +354,7 @@ def test_check_grad(self): 'Out', check_prim=True, check_prim_pir=True, - check_new_ir=True, + check_pir=True, ) def initTestCase(self): @@ -394,7 +394,7 @@ def init_op_type(self): self.use_mkldnn = False def test_check_output(self): - self.check_output(no_check_set=['XShape'], check_new_ir=True) + self.check_output(no_check_set=['XShape'], check_pir=True) def test_check_grad(self): pass diff --git a/test/legacy_test/test_tril_triu_op.py b/test/legacy_test/test_tril_triu_op.py index 1c64288dabbe57..d9de52a83999fd 100644 --- a/test/legacy_test/test_tril_triu_op.py +++ b/test/legacy_test/test_tril_triu_op.py @@ -19,7 +19,7 @@ import paddle from paddle import base, tensor from paddle.base import core -from paddle.base.framework import Program, program_guard +from paddle.pir_utils import test_with_pir_api class TrilTriuOpDefaultTest(OpTest): @@ -45,10 +45,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_new_ir=True) + self.check_output(check_pir=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out', check_new_ir=True) + self.check_grad(['X'], 'Out', check_pir=True) def init_dtype(self): self.dtype = np.float64 @@ -86,7 +86,7 @@ def initTestCase(self): self.X = np.arange(1, 101, dtype="float32").reshape([10, -1]) def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0), check_new_ir=True) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) def test_check_grad_normal(self): self.check_grad_with_place( @@ -94,7 +94,7 @@ def test_check_grad_normal(self): ['X'], 'Out', numeric_grad_delta=0.05, - check_new_ir=True, + check_pir=True, ) @@ -200,14 +200,15 @@ def initTestCase(self): class TestTrilTriuOpAPI(unittest.TestCase): """test case by using API and has -1 dimension""" + @test_with_pir_api def test_api(self): paddle.enable_static() dtypes = ['float16', 'float32'] for dtype in dtypes: - prog = Program() - startup_prog = Program() - with program_guard(prog, startup_prog): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): data = np.random.random([1, 9, 9, 4]).astype(dtype) x = paddle.static.data( shape=[1, 9, -1, 4], dtype=dtype, name='x' @@ -221,7 +222,7 @@ def test_api(self): ) exe = base.Executor(place) tril_out, triu_out = exe.run( - base.default_main_program(), + prog, feed={"x": data}, fetch_list=[tril_out, triu_out], ) @@ -243,14 +244,15 @@ def test_api_with_dygraph(self): np.testing.assert_allclose(tril_out, np.tril(data), rtol=1e-05) np.testing.assert_allclose(triu_out, np.triu(data), rtol=1e-05) + @test_with_pir_api def test_base_api(self): paddle.enable_static() dtypes = ['float16', 'float32'] for dtype in dtypes: - prog = Program() - startup_prog = Program() - with program_guard(prog, startup_prog): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): data = np.random.random([1, 9, 9, 4]).astype(dtype) x = paddle.static.data( shape=[1, 9, -1, 4], dtype=dtype, name='x' @@ -264,7 +266,7 @@ def test_base_api(self): ) exe = base.Executor(place) triu_out = exe.run( - base.default_main_program(), + prog, feed={"x": data}, fetch_list=[triu_out], ) diff --git a/test/legacy_test/test_uniform_random_op.py b/test/legacy_test/test_uniform_random_op.py index 29011739802f40..1e301f53d7fc2f 100644 --- a/test/legacy_test/test_uniform_random_op.py +++ b/test/legacy_test/test_uniform_random_op.py @@ -69,7 +69,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -101,7 +101,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -121,7 +121,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -141,7 +141,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -170,7 +170,7 @@ def init_attrs(self): self.output_hist = output_hist def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=True) + self.check_output_customized(self.verify_output, check_pir=True) def verify_output(self, outs): hist, prob = self.output_hist(np.array(outs[0])) @@ -244,7 +244,7 @@ def init_attrs(self): self.output_hist = output_hist_diag def test_check_output(self): - self.check_output_customized(self.verify_output, check_new_ir=False) + self.check_output_customized(self.verify_output, check_pir=False) class TestUniformRandomOpSelectedRows(unittest.TestCase): diff --git a/test/legacy_test/test_unique.py b/test/legacy_test/test_unique.py index 8fe9dfa9af6353..808cd8227bb7d4 100644 --- a/test/legacy_test/test_unique.py +++ b/test/legacy_test/test_unique.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestUniqueOp(OpTest): @@ -413,6 +414,7 @@ def test_dygraph_attr_dtype(self): self.assertTrue((inverse.numpy() == np_inverse).all(), True) self.assertTrue((counts.numpy() == np_counts).all(), True) + @test_with_pir_api def test_static_graph(self): with paddle_static_guard(): with paddle.static.program_guard( diff --git a/test/legacy_test/test_unsqueeze2_op.py b/test/legacy_test/test_unsqueeze2_op.py index 36fa88cb1035ac..cb1a6c868671ee 100755 --- a/test/legacy_test/test_unsqueeze2_op.py +++ b/test/legacy_test/test_unsqueeze2_op.py @@ -44,11 +44,20 @@ def if_enable_cinn(self): def test_check_output(self): self.check_output( - no_check_set=["XShape"], check_prim=True, check_new_ir=True + no_check_set=["XShape"], + check_prim=True, + check_pir=True, + check_prim_pir=True, ) def test_check_grad(self): - self.check_grad(["X"], "Out", check_prim=True, check_new_ir=True) + self.check_grad( + ["X"], + "Out", + check_prim=True, + check_pir=True, + check_prim_pir=True, + ) def init_test_case(self): self.ori_shape = (3, 40) @@ -137,10 +146,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"], check_new_ir=True) + self.check_output(no_check_set=["XShape"], check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) def init_test_case(self): self.ori_shape = (20, 5) @@ -198,10 +207,10 @@ def setUp(self): } def test_check_output(self): - self.check_output(no_check_set=["XShape"], check_new_ir=True) + self.check_output(no_check_set=["XShape"], check_pir=True) def test_check_grad(self): - self.check_grad(["X"], "Out", check_new_ir=True) + self.check_grad(["X"], "Out", check_pir=True) def init_test_case(self): self.ori_shape = (20, 5) diff --git a/test/legacy_test/test_where_op.py b/test/legacy_test/test_where_op.py index 3685a59b981347..ba9a5fbc3f0e1f 100644 --- a/test/legacy_test/test_where_op.py +++ b/test/legacy_test/test_where_op.py @@ -33,11 +33,11 @@ def setUp(self): self.outputs = {'Out': np.where(self.cond, self.x, self.y)} def test_check_output(self): - self.check_output(check_cinn=self.check_cinn, check_new_ir=True) + self.check_output(check_cinn=self.check_cinn, check_pir=True) def test_check_grad(self): self.check_grad( - ['X', 'Y'], 'Out', check_cinn=self.check_cinn, check_new_ir=True + ['X', 'Y'], 'Out', check_cinn=self.check_cinn, check_pir=True ) def init_config(self): @@ -85,7 +85,7 @@ def setUp(self): def test_check_output(self): place = core.CUDAPlace(0) self.check_output_with_place( - place, check_cinn=self.check_cinn, check_new_ir=True + place, check_cinn=self.check_cinn, check_pir=True ) def test_check_grad(self): @@ -96,7 +96,7 @@ def test_check_grad(self): 'Out', numeric_grad_delta=0.05, check_cinn=self.check_cinn, - check_new_ir=True, + check_pir=True, ) def init_config(self): diff --git a/test/mkldnn/test_activation_mkldnn_op.py b/test/mkldnn/test_activation_mkldnn_op.py index e6ef8388f771d1..d37cea47450c70 100644 --- a/test/mkldnn/test_activation_mkldnn_op.py +++ b/test/mkldnn/test_activation_mkldnn_op.py @@ -482,6 +482,14 @@ def setUp(self): self.outputs = {'Out': out} self.attrs = {"use_mkldnn": True} + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', check_pir=True) + class TestMKLDNNRound_ZeroDim(TestActivation_ZeroDim): def setUp(self): @@ -494,6 +502,14 @@ def setUp(self): self.outputs = {'Out': out} self.attrs = {"use_mkldnn": True} + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + if self.dtype == np.float16: + return + self.check_grad(['X'], 'Out', check_pir=True) + class TestMKLDNNSigmoidDim4(TestSigmoid): def setUp(self): diff --git a/test/mkldnn/test_reduce_bf16_mkldnn_op.py b/test/mkldnn/test_reduce_bf16_mkldnn_op.py index 187ce4cde47393..1d0e0e596dcb89 100644 --- a/test/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/test/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -40,7 +40,7 @@ def setUp(self): self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): - self.check_output(check_dygraph=False, check_new_ir=False) + self.check_output(check_dygraph=False, check_pir=False) def calculate_grads(self): tmp_tensor = np.zeros(self.x_fp32.shape).astype("float32") @@ -84,7 +84,7 @@ def test_check_grad(self): check_dygraph=False, user_defined_grads=[self.grad_X], user_defined_grad_outputs=[convert_float_to_uint16(self.grad_Out)], - check_new_ir=False, + check_pir=False, ) diff --git a/test/mkldnn/test_reduce_mkldnn_op.py b/test/mkldnn/test_reduce_mkldnn_op.py index 3dce2c72e55687..d22556f67630c0 100644 --- a/test/mkldnn/test_reduce_mkldnn_op.py +++ b/test/mkldnn/test_reduce_mkldnn_op.py @@ -29,12 +29,12 @@ def setUp(self): self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): - self.check_output(check_dygraph=False, check_new_ir=False) + self.check_output(check_dygraph=False, check_pir=False) class TestReduceDefaultWithGradOneDNNOp(TestReduceSumDefaultOneDNNOp): def test_check_grad(self): - self.check_grad(['X'], 'Out', check_dygraph=False, check_new_ir=False) + self.check_grad(['X'], 'Out', check_dygraph=False, check_pir=False) class TestReduceSum4DOneDNNOp(TestReduceDefaultWithGradOneDNNOp): diff --git a/test/prim/pir_prim/CMakeLists.txt b/test/prim/pir_prim/CMakeLists.txt index c31e7254ff60c9..049f4b915dc457 100644 --- a/test/prim/pir_prim/CMakeLists.txt +++ b/test/prim/pir_prim/CMakeLists.txt @@ -1,6 +1,6 @@ set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet test_prim_custom_vjp test_prim_jit - test_pir_prim_flags) + test_pir_prim_flags test_sink_decomp) foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES}) py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1 diff --git a/test/prim/pir_prim/test_sink_decomp.py b/test/prim/pir_prim/test_sink_decomp.py new file mode 100644 index 00000000000000..d1a14987123ee9 --- /dev/null +++ b/test/prim/pir_prim/test_sink_decomp.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.autograd.ir_backward import grad +from paddle.base import core +from paddle.decomposition import decompose + +paddle.enable_static() + + +class TestPrimMode(unittest.TestCase): + def setUp(self): + np.random.seed(2023) + self.shape_x = [8, 16, 32, 64] + self.shape_y = [8, 16, 32, 64] + self.x = np.random.random(self.shape_x).astype("float32") + self.y = np.random.random(self.shape_y).astype("float32") + self.prog = None + + def base_net(self, flag=None): + if flag == "forward": + core._set_prim_forward_enabled(True) + elif flag == "backward": + core._set_prim_backward_enabled(True) + elif flag == "all": + core._set_prim_all_enabled(True) + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data('x', self.shape_x, dtype='float32') + y = paddle.static.data('y', self.shape_y, dtype='float32') + x.stop_gradient = False + y.stop_gradient = False + divide_out = paddle.divide(x, y) + sum_out = paddle.mean(divide_out, axis=0) + [new_out] = decompose(main_program, [sum_out]) + gradients = grad(new_out, (x, y)) + + exe = paddle.static.Executor() + [fwd, dx, dy] = exe.run( + feed={'x': self.x, 'y': self.y}, fetch_list=[new_out, gradients] + ) + + whole_ops = [op.name() for op in main_program.global_block().ops] + self.prog = main_program + if flag == "forward": + core._set_prim_forward_enabled(False) + assert ( + 'pd_op.mean' not in whole_ops + and 'pd_op.divide_grad' in whole_ops + ) + elif flag == "backward": + core._set_prim_backward_enabled(False) + assert ( + 'pd_op.mean' in whole_ops + and 'pd_op.divide_grad' not in whole_ops + ) + elif flag == "all": + core._set_prim_all_enabled(False) + assert ( + 'pd_op.mean' not in whole_ops + and 'pd_op.divide_grad' not in whole_ops + ) + else: + assert ( + 'pd_op.mean' in whole_ops and 'pd_op.divide_grad' in whole_ops + ) + return fwd, dx, dy + + def test_prim_forward(self): + res_ref = self.base_net() + res = self.base_net("forward") + for ref, actual in zip(res_ref, res): + np.testing.assert_equal(ref, actual) + + def test_prim_backward(self): + res_ref = self.base_net() + res = self.base_net("backward") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + def test_prim_all(self): + res_ref = self.base_net() + res = self.base_net("all") + for ref, actual in zip(res_ref, res): + np.testing.assert_allclose(ref, actual, rtol=1e-6) + + def test_has_decomp(self): + _ = self.base_net() + for op in self.prog.global_block().ops: + if op.name() == "pd_op.divide": + self.assertEqual(core.has_decomp(op), False) + if op.name() == "pd_op.mean": + self.assertEqual(core.has_decomp(op), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/extract_errors.py b/test/sot/extract_errors.py new file mode 100644 index 00000000000000..b9d9e505724ef0 --- /dev/null +++ b/test/sot/extract_errors.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys + +runtime_error_msg = sys.stdin.read() + +pattern = r'File "?(.*?)"?, line (\d+),.*\n(.*?)\n(.*?)$' +for match in re.finditer(pattern, runtime_error_msg, re.MULTILINE): + file = match.group(1) + if file.startswith("./"): + file = f"tests/{file[2:]}" + line = match.group(2) + error_info = match.group(4) + if "AssertionError" not in error_info: + # error_info = match.group(3) + '\n' + match.group(4) + output = f"::error file={file},line={line}::Error" + print(output) diff --git a/test/sot/test_01_basic.py b/test/sot/test_01_basic.py new file mode 100644 index 00000000000000..8a03ea9fd3ae5a --- /dev/null +++ b/test/sot/test_01_basic.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x: int, y: paddle.Tensor): + return x + y + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + + +def numpy_add(x, y): + out = paddle.to_tensor(x.numpy() + y.numpy()) + return out + + +class TestNumpyAdd(TestCaseBase): + @strict_mode_guard(0) + def test_numpy_add(self): + x = paddle.to_tensor([2]) + y = paddle.to_tensor([3]) + self.assert_results(numpy_add, x, y) + + +if __name__ == "__main__": + unittest.main() + + +# Instructions: +# LOAD_FAST +# BINARY_ADD +# RETURN_VALUE + +# Variables: +# ConstantVariable +# TensorVariable diff --git a/test/sot/test_02_store_inplace.py b/test/sot/test_02_store_inplace.py new file mode 100644 index 00000000000000..3c9b4df4602a05 --- /dev/null +++ b/test/sot/test_02_store_inplace.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def foo(x: int, y: paddle.Tensor): + x = x + 1 + y = y + 1 + x += y + return x + + +class TestStoreInplace(TestCaseBase): + def test_simple(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() + + +# Instructions: +# LOAD_FAST +# BINARY_ADD +# STORE_FAST (new) +# INPLACE_ADD (new) +# RETURN_VALUE + +# Variables: +# ConstantVariable +# TensorVariable diff --git a/test/sot/test_03_tuple.py b/test/sot/test_03_tuple.py new file mode 100644 index 00000000000000..797d54384714d0 --- /dev/null +++ b/test/sot/test_03_tuple.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_TUPLE +# BINARY_SUBSCR + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def build_tuple(x: int, y: paddle.Tensor): + x = (x, y) + return x[1] + 1 + + +@check_no_breakgraph +def build_tuple_with_slice_subscript(x: int, y: paddle.Tensor): + z = (x, y, 3, 4) + return z[0:5:1] + + +@check_no_breakgraph +def build_tuple_with_int_subscript(x: int, y: paddle.Tensor): + z = (x, y) + return z[0] + + +@check_no_breakgraph +def tuple_count_int(x: int, y: paddle.Tensor): + z = (x, x, 2, 1) + return z.count(x) + + +def tuple_count_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): + return y.count(x) + + +@check_no_breakgraph +def tuple_index_int(x: int, y: paddle.Tensor): + z = (x, y, x, y, y) + return z.index(x) + + +def tuple_index_tensor(x: paddle.Tensor, y: tuple[paddle.Tensor]): + return y.index(x) + + +class TestBuildTuple(TestCaseBase): + def test_build_tuple(self): + self.assert_results(build_tuple, 1, paddle.to_tensor(2)) + self.assert_results( + build_tuple_with_slice_subscript, 1, paddle.to_tensor(2) + ) + self.assert_results( + build_tuple_with_int_subscript, 1, paddle.to_tensor(2) + ) + + +class TestTupleMethods(TestCaseBase): + def test_tuple_methods_int(self): + self.assert_results(tuple_count_int, 1, paddle.to_tensor(2)) + self.assert_results(tuple_index_int, 1, paddle.to_tensor(2)) + + def test_tuple_methods_tensor(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + self.assert_results(tuple_count_tensor, a, (a, b, a, b)) + self.assert_results(tuple_index_tensor, b, (b, b, b, a)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_04_list.py b/test/sot/test_04_list.py new file mode 100644 index 00000000000000..d8b0823a279c21 --- /dev/null +++ b/test/sot/test_04_list.py @@ -0,0 +1,327 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_LIST (new) +# BINARY_SUBSCR +# DELETE_SUBSCR + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def list_getitem_int(x: int, y: paddle.Tensor): + x = [x, y] + return x[0] + 1 + + +@check_no_breakgraph +def list_getitem_tensor(x: int, y: paddle.Tensor): + x = [x, y] + return x[1] + 1 + + +@check_no_breakgraph +def list_setitem_int(x: int, y: paddle.Tensor): + z = [x, y] + z[0] = 3 + return z + + +def list_setitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + z[1] = paddle.to_tensor(3) + return z + + +@check_no_breakgraph +def list_delitem_int(x: int, y: paddle.Tensor): + z = [x, y] + del z[0] + return z + + +@check_no_breakgraph +def list_delitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + del z[1] + return z + + +@check_no_breakgraph +def list_construct_from_list(x: int, y: paddle.Tensor): + z = [x, y] + return z + + +@check_no_breakgraph +def list_append_int(x: int, y: paddle.Tensor): + z = [x, y] + z.append(3) + return z + + +@check_no_breakgraph +def list_append_tensor(x: int, y: paddle.Tensor): + z = [x, y] + z.append(y) + return z + + +@check_no_breakgraph +def list_clear(x: int, y: paddle.Tensor): + z = [x, y] + z.clear() + return z + + +@check_no_breakgraph +def list_copy(x: int, y: paddle.Tensor): + z = [x, y] + a = z.copy() + z[0] = 3 + z[1] = y + 1 + return (a, z) + + +@check_no_breakgraph +def list_count_int(x: int, y: paddle.Tensor): + z = [x, x, 2, 3, 1] + return z.count(x) + + +def list_count_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.count(x) + + +@check_no_breakgraph +def list_extend(x: int, y: paddle.Tensor): + z = [x, y] + a = [y, x] + b = (x, y) + z.extend(a) + z.extend(b) + return z + + +@check_no_breakgraph +def list_index_int(x: int, y: paddle.Tensor): + z = [x, x, 1, 2] + return z.index(x) + + +def list_index_tensor(x: paddle.Tensor, y: list[paddle.Tensor]): + return y.index(x) + + +@check_no_breakgraph +def list_insert(x: int, y: paddle.Tensor): + z = [x, y] + z.insert(0, x) + z.insert(3, y) + return z + + +@check_no_breakgraph +def list_pop(x: int, y: paddle.Tensor): + z = [x, y] + a = z.pop() + b = z.pop() + return (z, a, b) + + +@check_no_breakgraph +def list_remove(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.remove(x) + z.remove(y) + return z + + +@check_no_breakgraph +def list_reverse(x: int, y: paddle.Tensor): + z = [x, x, y, y] + z.reverse() + return z + + +@check_no_breakgraph +def list_default_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort() + return z + + +@check_no_breakgraph +def list_key_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort(lambda x: x) + return z + + +@check_no_breakgraph +def list_reverse_sort(x: int, y: paddle.Tensor): + z = [x + 2, x, x + 1] + z.sort(reverse=True) + return z + + +@check_no_breakgraph +def list_tensor_sort(x: int, y: paddle.Tensor): + z = [y + 2, y, y + 1] + z.sort() + return z + + +@check_no_breakgraph +def list_max(x: paddle.Tensor | int, y: paddle.Tensor | int): + z = [x, x, y] + return max(z) + + +@check_no_breakgraph +def list_tensor_max_api(x: paddle.Tensor): + return x.max() + + +@check_no_breakgraph +def list_min(x: paddle.Tensor | int, y: paddle.Tensor | int): + z = [x, x, y] + return min(z) + + +@check_no_breakgraph +def list_tensor_min_api(x: paddle.Tensor): + return x.min() + + +@check_no_breakgraph +def list_no_arguments(): + l1 = list() # noqa: C408 + l1.append(1) + l2 = list() # noqa: C408 + l2.append(2) + return l1[0] + l2[0] + + +class TestListBasic(TestCaseBase): + def test_list_basic(self): + self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) + self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_setitem_int, 1, paddle.to_tensor(2) + ) + + +class TestListMethods(TestCaseBase): + def test_list_setitem(self): + self.assert_results_with_side_effects( + list_setitem_tensor, 1, paddle.to_tensor(2) + ) + + def test_list_count_and_index(self): + self.assert_results(list_count_int, 1, paddle.to_tensor(2)) + self.assert_results(list_index_int, 1, paddle.to_tensor(2)) + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + self.assert_results(list_count_tensor, a, [a, b, a, b, a, b]) + self.assert_results(list_index_tensor, b, [a, b, a, b, a, b]) + + def test_list_delitem(self): + self.assert_results_with_side_effects( + list_delitem_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_delitem_tensor, 1, paddle.to_tensor(2) + ) + + def test_list_append(self): + self.assert_results_with_side_effects( + list_append_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_append_tensor, 1, paddle.to_tensor(2) + ) + + def test_list_clear(self): + self.assert_results_with_side_effects( + list_clear, 1, paddle.to_tensor(2) + ) + + def test_list_copy(self): + self.assert_results_with_side_effects(list_copy, 1, paddle.to_tensor(2)) + + def test_list_extend(self): + self.assert_results_with_side_effects( + list_extend, 1, paddle.to_tensor(2) + ) + + def test_list_insert(self): + self.assert_results_with_side_effects( + list_insert, 1, paddle.to_tensor(2) + ) + + def test_list_pop(self): + self.assert_results_with_side_effects(list_pop, 1, paddle.to_tensor(2)) + + def test_list_remove(self): + self.assert_results_with_side_effects( + list_remove, 1, paddle.to_tensor(2) + ) + + def test_list_reverse(self): + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + list_reverse, 1, paddle.to_tensor(2) + ) + + def test_list_sort(self): + self.assert_results_with_side_effects( + list_default_sort, 1, paddle.to_tensor(2) + ) + # TODO: Not currently supported + # self.assert_results_with_side_effects( + # list_tensor_sort, 1, paddle.to_tensor(2) + # ) + # self.assert_results_with_side_effects( + # list_key_sort, 1, paddle.to_tensor(2) + # ) + # self.assert_results_with_side_effects( + # list_reverse_sort, 1, paddle.to_tensor(2) + # ) + + def test_list_construct_from_list(self): + self.assert_results(list_construct_from_list, 1, paddle.to_tensor(2)) + + def test_list_max_min(self): + self.assert_results(list_max, 1, 2) + self.assert_results(list_min, 1, 2) + self.assert_results(list_tensor_max_api, paddle.to_tensor([1, 2, 3])) + self.assert_results(list_tensor_min_api, paddle.to_tensor([1, 2, 3])) + + def test_list_noargs(self): + self.assert_results(list_no_arguments) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_05_dict.py b/test/sot/test_05_dict.py new file mode 100644 index 00000000000000..7014a717467984 --- /dev/null +++ b/test/sot/test_05_dict.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_MAP (new) +# BUILD_CONST_KEY_MAP (new) + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def build_map(x: int, y: paddle.Tensor): + z = {x: y} + return z[x] + 1 + + +@check_no_breakgraph +def build_const_key_map(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return z[x] + 1 + + +@check_no_breakgraph +def dict_get_item(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + return (z.get(1), z.get(2)) + + +@check_no_breakgraph +def dict_get_item_default(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + return (z.get(3, 2), z.get(4, y)) + + +@check_no_breakgraph +def dict_set_item_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z[1] = x * 2 + return z[1] + + +@check_no_breakgraph +def dict_set_item_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z[2] = y + return z[1] + + +@check_no_breakgraph +def dict_update_item1(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.update({1: x * 2, 2: y, 3: y + 2}) + return z + + +@check_no_breakgraph +def dict_update_item2(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.update({1: x * 2, 2: y, 3: z[2] + 2}) + return z + + +@check_no_breakgraph +def dict_del_item_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + del z[1] + return z + + +@check_no_breakgraph +def dict_del_item_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + del z[2] + return z + + +@check_no_breakgraph +def dict_clear(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z.clear() + return z + + +@check_no_breakgraph +def dict_copy(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + z2 = z.copy() + z[1] = 2 + return z2 + + +@check_no_breakgraph +def dict_setdefault_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + a = z.setdefault(4) + b = z.setdefault(1, 2) + c = z.setdefault(3, 4) + return (z, a, b, c) + + +@check_no_breakgraph +def dict_pop(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1, 3: y} + a = z.pop(1) + b = z.pop(2, 3) + c = z.pop(4, 3) + d = z.pop(5, y) + return (z, a, b, c, d) + + +@check_no_breakgraph +def dict_popitem(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1, 3: y} + a = z.popitem() + return (z, a) + + +@check_no_breakgraph +def dict_construct_from_dict(): + x = {1: 2, 3: 4} + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_list(): + x = [[1, 2], [3, 4]] + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_tuple(): + x = ((1, 2), (3, 4)) + d = dict(x) + return d + + +@check_no_breakgraph +def dict_construct_from_comprehension(): + z = {1: 2, 3: 4} + d = {k: v + 1 for k, v in z.items()} + return d + + +@check_no_breakgraph +def dict_no_arguments(): + d1 = dict() # noqa: C408 + d1.update({1: 2}) + d2 = dict() # noqa: C408 + d2.update({3: 4}) + return d1[1] + d2[3] + + +@check_no_breakgraph +def dict_test_fromkeys(x): + d = dict.fromkeys(x) + return d + + +@check_no_breakgraph +def dict_test_fromkeys_defalut(x, y): + d = dict.fromkeys(x, y) + return d + + +class TestBuildDict(TestCaseBase): + def test_build_map(self): + self.assert_results(build_map, 1, paddle.to_tensor(2)) + + def test_build_const_key_map(self): + self.assert_results(build_const_key_map, 1, paddle.to_tensor(2)) + + +class TestDictMethods(TestCaseBase): + def test_dict_get_item(self): + self.assert_results(dict_get_item, 1, paddle.to_tensor(2)) + self.assert_results(dict_get_item_default, 1, paddle.to_tensor(2)) + + def test_dict_set_item(self): + self.assert_results_with_side_effects( + dict_set_item_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_set_item_tensor, 1, paddle.to_tensor(2) + ) + + def test_dict_copy(self): + self.assert_results_with_side_effects(dict_copy, 1, paddle.to_tensor(2)) + + def test_dict_update(self): + self.assert_results_with_side_effects( + dict_update_item1, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_update_item2, 1, paddle.to_tensor(2) + ) + + def test_dict_setdefault(self): + self.assert_results_with_side_effects( + dict_setdefault_int, 1, paddle.to_tensor(2) + ) + + def test_dict_del_item(self): + self.assert_results_with_side_effects( + dict_del_item_int, 1, paddle.to_tensor(2) + ) + self.assert_results_with_side_effects( + dict_del_item_tensor, 1, paddle.to_tensor(2) + ) + + def test_dict_clear(self): + self.assert_results_with_side_effects( + dict_clear, 1, paddle.to_tensor(2) + ) + + def test_dict_pop(self): + self.assert_results_with_side_effects(dict_pop, 1, paddle.to_tensor(2)) + + def test_dict_popitem(self): + self.assert_results_with_side_effects( + dict_popitem, 1, paddle.to_tensor(2) + ) + + def test_construct(self): + self.assert_results(dict_construct_from_dict) + self.assert_results(dict_construct_from_list) + self.assert_results(dict_construct_from_tuple) + self.assert_results(dict_construct_from_comprehension) + + def test_dict_noargs(self): + self.assert_results(dict_no_arguments) + + def test_dict_fromkeys(self): + self.assert_results(dict_test_fromkeys, (1, 2, 3, 4)) + self.assert_results(dict_test_fromkeys, [1, 2, 3, 4]) + self.assert_results(dict_test_fromkeys_defalut, (1, 2, 3, 4), 1) + self.assert_results( + dict_test_fromkeys_defalut, (1, 2, 3, 4), paddle.to_tensor(1) + ) + self.assert_results(dict_test_fromkeys_defalut, [1, 2, 3, 4], 1) + self.assert_results( + dict_test_fromkeys_defalut, [1, 2, 3, 4], paddle.to_tensor(1) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_06_call_function.py b/test/sot/test_06_call_function.py new file mode 100644 index 00000000000000..4358afe6ca985f --- /dev/null +++ b/test/sot/test_06_call_function.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def add(x, y): + return x + y + + +def sub(x, y): + return x - y + + +def foo_1(x: paddle.Tensor): + m = x + 1 + y = add(m * 3, m * 2) + return y + + +def foo_2(x: paddle.Tensor): + m = x + 1 + y = sub(m * 3, m * 2) + return y + + +def foo_3(x: paddle.Tensor): + m = x + 1 + y = sub(m * 3, m * 2) + y = sub(y, y) + y = sub(y, y) + return y + + +def nest_2(x): + return x + 1 + + +def nest_1(x): + return (x - 1) * 2 + + +def foo_4(x: paddle.Tensor): + m = x + 1 + m = nest_1(m) + return m + + +def fn_with_varargs_and_kwargs(x, *args, **kwargs): + return ( + x + + args[0] + + args[1] + - args[2] + + kwargs['a'] * kwargs['b'] / kwargs['c'] + ) + + +def foo_5(x: paddle.Tensor): + m = x + 1 + m = fn_with_varargs_and_kwargs( + m, x + 1, x + 2, x + 3, a=x + 4, b=x + 5, c=x + 6 + ) + return m + + +def fn_with_default_value(x, y=1, z=2): + return x + y + z + + +def foo_6(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value(m, m + 10) + m = fn_with_default_value(m + 42) + return m + + +def fn_with_default_value_and_varargs_kwargs(x, y=1, *args, **kwargs): + return x + y + args[0] + kwargs['a'] + + +def foo_7(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value_and_varargs_kwargs(m, m + 1, m + 2, a=m + 3) + return m + + +def fn_with_default_value_and_varargs_kwargs_kwonly_1( + x, y=1, *args, z, **kwargs +): + return x + y + args[0] + kwargs['a'] + z + + +def fn_with_default_value_and_varargs_kwargs_kwonly_2( + x, y=1, *args, z=10, **kwargs +): + return x + y + args[0] + kwargs['a'] + z + + +def foo_8(x: paddle.Tensor): + m = x + 1 + m = fn_with_default_value_and_varargs_kwargs_kwonly_1( + m, m + 1, m + 2, a=m + 3, z=m + 4 + ) + m = fn_with_default_value_and_varargs_kwargs_kwonly_2( + m, m + 1, m + 2, a=m + 3 + ) + return m + + +class TestCall(TestCaseBase): + def test_call1(self): + self.assert_results(foo_1, paddle.to_tensor(2)) + + def test_call2(self): + self.assert_results(foo_2, paddle.to_tensor(3)) + + def test_call3(self): + self.assert_results(foo_3, paddle.to_tensor(4)) + + def test_call4(self): + self.assert_results(foo_4, paddle.to_tensor(5)) + + def test_call5(self): + self.assert_results(foo_5, paddle.to_tensor(6)) + + def test_call6(self): + self.assert_results(foo_6, paddle.to_tensor(7)) + + def test_call7(self): + self.assert_results(foo_7, paddle.to_tensor(8)) + + def test_call8(self): + self.assert_results(foo_8, paddle.to_tensor(9)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_07_unpack.py b/test/sot/test_07_unpack.py new file mode 100644 index 00000000000000..f04a185294b6f5 --- /dev/null +++ b/test/sot/test_07_unpack.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# UNPACK_SEQUENCE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def unpack_tuple(x: tuple[int, paddle.Tensor]): + y, z = x + return z + 1 + + +def unpack_tensor(x: paddle.Tensor): + a, b = x + return (a, b) + + +def unpack_ex_tuple(x: tuple[int, int, paddle.Tensor]): + *y, z = x + return z + 1 + + +def unpack_ex_tensor(x: paddle.Tensor): + a, b, *c = x + return (a, b) + + +def unpack_ex_tensor_2(x: paddle.Tensor): + a, *b, c, d = x + return (a, c) + + +class TestUnpack(TestCaseBase): + def test_unpack_tuple(self): + self.assert_results(unpack_tuple, (1, paddle.to_tensor(2))) + + def test_unpack_tensor(self): + self.assert_results(unpack_tensor, paddle.to_tensor([2, 3])) + + def test_unpack_ex_tuple(self): + self.assert_results(unpack_ex_tuple, (1, 1, paddle.to_tensor(2))) + + def test_unpack_ex_tensor(self): + self.assert_results(unpack_ex_tensor, paddle.to_tensor([2, 3, 3, 3])) + + def test_unpack_ex_tensor_2(self): + self.assert_results(unpack_ex_tensor_2, paddle.to_tensor([2, 3, 3, 3])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_08_rot.py b/test/sot/test_08_rot.py new file mode 100644 index 00000000000000..2d9146e3ff3baf --- /dev/null +++ b/test/sot/test_08_rot.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def rot_two_return_a(a: paddle.Tensor, b: paddle.Tensor): + b, a = a, b + return a + 1 + + +def rot_two_return_b(a: paddle.Tensor, b: paddle.Tensor): + b, a = a, b + return b + 2 + + +def rot_three_return_a(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return a + 1 + + +def rot_three_return_b(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return b + 1 + + +def rot_three_return_c(a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor): + a, b, c = c, b, a + return c + 1 + + +def rot_four_return_a( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return a + 1 + + +def rot_four_return_b( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return b + 1 + + +def rot_four_return_c( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return c + 1 + + +def rot_four_return_d( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + a, b, c, d = d, c, b, a + return d + 1 + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + self.assert_results(rot_two_return_a, a, b) + self.assert_results(rot_two_return_b, a, b) + + self.assert_results(rot_three_return_a, a, b, c) + self.assert_results(rot_three_return_b, a, b, c) + self.assert_results(rot_three_return_c, a, b, c) + + self.assert_results(rot_four_return_a, a, b, c, d) + self.assert_results(rot_four_return_b, a, b, c, d) + self.assert_results(rot_four_return_c, a, b, c, d) + self.assert_results(rot_four_return_d, a, b, c, d) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_09_f_string.py b/test/sot/test_09_f_string.py new file mode 100644 index 00000000000000..c2a3b8144605bf --- /dev/null +++ b/test/sot/test_09_f_string.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# FORMAT_VALUE (new) +# BUILD_STRING (new) +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import assert_true + + +def foo(x: paddle.Tensor): + whilespace = 123 + hello_world = f"Hello {whilespace} World" + z = assert_true(hello_world == "Hello 123 World") + x = x + 1 + return x + + +class TestFString(TestCaseBase): + def test_fstring(self): + self.assert_results(foo, paddle.to_tensor(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_10_build_unpack.py b/test/sot/test_10_build_unpack.py new file mode 100644 index 00000000000000..0b35c469018632 --- /dev/null +++ b/test/sot/test_10_build_unpack.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# BUILD_TUPLE_UNPACK (new) +# BUILD_LIST_UNPACK (new) +# BUILD_TUPLE_UNPACK_WITH_CALL (new) +# CALL_FUNCTION_EX (new) +# BUILD_MAP_UNPACK (new) +# LIST_EXTEND (new) +# LIST_TO_TUPLE (new) +# DICT_UPDATE (new) +# DICT_MERGE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def build_tuple_unpack(x: tuple[paddle.Tensor], y: tuple[paddle.Tensor]): + z = (*x, *y) + + return z[0] + 1 + + +def build_list_unpack(x: list[paddle.Tensor], y: list[paddle.Tensor]): + z = [*x, *y] + return z[0] + 1 + + +def build_tuple_unpack_with_call( + x: tuple[paddle.Tensor], y: tuple[paddle.Tensor] +): + z = build_tuple_unpack_with_call_inner(*x, *y) + return z[0] + 1 + + +def build_tuple_unpack_with_call_inner( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + z = (a, b, c, d) + return z + + +def build_map_unpack(x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor]): + z = {**x, **y} + return z["a"] + 1 + + +def build_map_unpack_with_call_inner( + a: paddle.Tensor, b: paddle.Tensor, c: paddle.Tensor, d: paddle.Tensor +): + z = {"a": a, "b": b, "c": c, "d": d} + return z + + +def build_map_unpack_with_call( + x: dict[str, paddle.Tensor], y: dict[str, paddle.Tensor] +): + z = build_map_unpack_with_call_inner(**x, **y) + return z["a"] + 1 + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + + self.assert_results(build_tuple_unpack, (a, b), (c, d)) + self.assert_results(build_list_unpack, [a, b], [c, d]) + self.assert_results(build_tuple_unpack_with_call, (a, b), (c, d)) + self.assert_results( + build_map_unpack, {"a": a, "b": b}, {"c": c, "d": d} + ) + self.assert_results( + build_map_unpack_with_call, {"a": a, "b": b}, {"c": c, "d": d} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_11_jumps.py b/test/sot/test_11_jumps.py new file mode 100644 index 00000000000000..80fa1f4a4eb02b --- /dev/null +++ b/test/sot/test_11_jumps.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +@check_no_breakgraph +def pop_jump_if_false(x: bool, y: paddle.Tensor): + if x: + y += 1 + else: + y -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_true(x: bool, y: bool, z: paddle.Tensor): + return (x or y) and z + + +@check_no_breakgraph +def jump_if_false_or_pop(x: bool, y: paddle.Tensor): + return x and (y + 1) + + +@check_no_breakgraph +def jump_if_true_or_pop(x: bool, y: paddle.Tensor): + return x or (y + 1) + + +@check_no_breakgraph +def jump_absolute(x: int, y: paddle.Tensor): + while x > 0: + y += 1 + x -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_none(x: bool, y: paddle.Tensor): + if x is not None: + y += 1 + else: + y -= 1 + return y + + +@check_no_breakgraph +def pop_jump_if_not_none(x: bool, y: paddle.Tensor): + if x is None: + y += 1 + else: + y -= 1 + return y + + +a = paddle.to_tensor(1) +b = paddle.to_tensor(2) +c = paddle.to_tensor(3) +d = paddle.to_tensor(4) + +true_tensor = paddle.to_tensor(True) +false_tensor = paddle.to_tensor(False) + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(jump_absolute, 5, a) + + self.assert_results(pop_jump_if_false, True, a) + self.assert_results(pop_jump_if_false, False, a) + self.assert_results(jump_if_false_or_pop, True, a) + self.assert_results(jump_if_false_or_pop, False, a) + self.assert_results(jump_if_true_or_pop, True, a) + self.assert_results(jump_if_true_or_pop, False, a) + self.assert_results(pop_jump_if_true, True, False, a) + self.assert_results(pop_jump_if_true, False, False, a) + + self.assert_results(pop_jump_if_none, None, a) + self.assert_results(pop_jump_if_none, True, a) + self.assert_results(pop_jump_if_not_none, None, a) + self.assert_results(pop_jump_if_not_none, True, a) + + def test_breakgraph(self): + self.assert_results(pop_jump_if_false, true_tensor, a) + self.assert_results(jump_if_false_or_pop, true_tensor, a) + self.assert_results(jump_if_true_or_pop, false_tensor, a) + self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) + self.assert_results(jump_absolute, 5, a) + self.assert_results(pop_jump_if_false, false_tensor, a) + self.assert_results(jump_if_false_or_pop, false_tensor, a) + self.assert_results(jump_if_true_or_pop, false_tensor, a) + self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a) + + self.assert_results(pop_jump_if_none, true_tensor, a) + self.assert_results(pop_jump_if_not_none, true_tensor, a) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_12_for_loop.py b/test/sot/test_12_for_loop.py new file mode 100644 index 00000000000000..63e3fedace4bfd --- /dev/null +++ b/test/sot/test_12_for_loop.py @@ -0,0 +1,298 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GET_ITER (new) +# FOR_ITER (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +def gener(): + yield 1 + yield 2 + yield 3 + + +def for_list_1(x: paddle.Tensor): + for i in [1, 2, 3]: + x += i + + if x > 2: + x += 1 + else: + x -= 1 + return x + + +def for_list_2(x: paddle.Tensor): + for i in [1, 2, 3]: + x += i + + if i > 2: + x += 1 + else: + x -= 1 + return x + + +def for_dict(x: paddle.Tensor): + map = {1: 2, 3: 4} + for k in map.keys(): + x += k + + for v in map.values(): + x += v + + for k, v in map.items(): + x += k + x += v + + return x + + +def for_iter(x, it): + for item in it: + x += item + return x + + +def for_for_fallback(x, it): + for i in [1, 2, 3]: + for item in it: + x += item + return x + + +def for_break(x: paddle.Tensor, it): + for i in [1, 2, 3]: + x += i + if i == 2: + break + for i in it: + x += i + if i == 2: + break + return x + + +def for_continue(x: paddle.Tensor, it): + for i in [1, 2, 3]: + if i == 2: + continue + x += i + + for i in it: + if i == 2: + continue + x += i + return x + + +def for_enumerate_var_with_nested_range(x_array): + x = paddle.tensor.fill_constant([1], 'int32', 0) + x_array = paddle.to_tensor(x_array) + for i, num in enumerate(x_array): + for idx in range(num): + x = x + num + return x + + +def for_create_tmp_in_loop(x, it): + s = x + for i in it: + tmp = i + s += tmp + return s, tmp + + +def for_without_zero_iter(self_res_dict, output): + res_dict = {"logits": output} + for res_key in list(self_res_dict): + res_dict[res_key] = self_res_dict.pop(res_key) + return res_dict + + +@sot.psdb.check_no_fallback +def for_reconstruct_range_iter(): + for i in range(3): + sot.psdb.breakgraph() + + +global_var_name = None + + +def for_tmp_var_with_same_name_as_global_var(): + total = 0 + for i in range(3): + global_var_name = i + 3 + sot.psdb.breakgraph() + total += global_var_name + return total + + +def for_layer_list(layer_list, x): + for net in layer_list: + x = net(x) + return x + + +class TestForLoop(TestCaseBase): + def test_list(self): + a = paddle.to_tensor(1) + self.assert_results(for_list_1, a) + + def test_list_with_fallback(self): + a = paddle.to_tensor(1) + self.assert_results(for_list_2, a) + + def test_dict(self): + a = paddle.to_tensor(1) + self.assert_results(for_dict, a) + + def test_fallback(self): + a = paddle.to_tensor(1) + + sym_output = symbolic_translate(for_iter)(a, gener()) + paddle_output = for_iter(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_for_fallback(self): + a = paddle.to_tensor(1) + + sym_output = symbolic_translate(for_iter)(a, gener()) + paddle_output = for_iter(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_break(self): + a = paddle.to_tensor(1) + sym_output = symbolic_translate(for_break)(a, gener()) + paddle_output = for_break(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + def test_for_continue(self): + a = paddle.to_tensor(1) + sym_output = symbolic_translate(for_continue)(a, gener()) + paddle_output = for_continue(a, gener()) + self.assert_nest_match(sym_output, paddle_output) + + # TODO(zmh): support range for tensor + # def test_resume_stack(self): + # a = [1, 2, 3] + # self.assert_results(for_enumerate_var_with_nested_range, a) + + def test_create_var_in_loop(self): + x = paddle.to_tensor(1, dtype="float32") + a = [1, 2, 3] + self.assert_results(for_create_tmp_in_loop, x, a) + + sym_output = symbolic_translate(for_create_tmp_in_loop)(x, iter(a)) + paddle_output = for_create_tmp_in_loop(x, iter(a)) + self.assert_nest_match(sym_output, paddle_output) + + def test_create_var_in_loop_with_same_name_as_global(self): + self.assert_results(for_tmp_var_with_same_name_as_global_var) + + def test_for_without_zero_iter(self): + self_res_dict = {} + output = paddle.to_tensor(2) + self.assert_results(for_without_zero_iter, self_res_dict, output) + + def test_reconstruct_range_iter(self): + self.assert_results(for_reconstruct_range_iter) + + def test_layer_list(self): + layers = paddle.nn.LayerList() + for i in range(5): + layers.append(paddle.nn.Linear(5, 5)) + x = paddle.rand([5], dtype="float32") + self.assert_results(for_layer_list, layers, x) + + +def run_list_comp(x): + out = [s.chunk(2, axis=1) for s in x] + return out + + +class TestListComp(TestCaseBase): + def test_list_comp(self): + x = [paddle.randn([1, 4]), paddle.randn([1, 4])] + self.assert_results(run_list_comp, x) + + +def for_enumerate_cache(func_list, x): + out = None + for idx, func in enumerate(func_list): + out = func(x[idx]) + return out + + +class TestEnumerateCache(TestCaseBase): + def test_run(self): + func_list = [ + paddle.nn.Linear(10, 10), + ] + x = [ + paddle.randn([5, 10]), + ] + + out = symbolic_translate(for_enumerate_cache)(func_list, x) + out = symbolic_translate(for_enumerate_cache)(func_list, x) + self.assert_nest_match(OpcodeExecutorCache().translate_count, 1) + + +# after_loop_fn need zzz, and zzz is created as UndefinedVar when generating loop body +# do not set zzz as UndefinedVar again +def undefined_var_case_0(): + for i in [1, 2]: + sot.psdb.breakgraph() + zzz = i + + zzz = zzz + 1 + return zzz + + +# after_loop_fn need create zzz as UndefinedVar +def undefined_var_case_1(): + for i in [1, 2]: + sot.psdb.breakgraph() + aaa = i + + for i in [1, 3]: + zzz = i + zzz = zzz + 1 + return zzz + + +class TestUndefinedVarInRiskyCodes(TestCaseBase): + def test_undefined_var_case_0(self): + self.assert_results(undefined_var_case_0) + + def test_undefined_var_case_1(self): + self.assert_results(undefined_var_case_1) + + +if __name__ == "__main__": + with strict_mode_guard(0): + unittest.main() diff --git a/test/sot/test_13_make_function.py b/test/sot/test_13_make_function.py new file mode 100644 index 00000000000000..9784d7ffad385f --- /dev/null +++ b/test/sot/test_13_make_function.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MAKE_FUNCTION +# CALL_FUNCTION_KW +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def make_fn(x: paddle.Tensor): + def fn(a, b=2, c=3, d=4): + return a + b + c + d + + return fn(1) + fn(2, c=5) + x + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(make_fn, paddle.to_tensor(1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_14_operators.py b/test/sot/test_14_operators.py new file mode 100644 index 00000000000000..fc403ae3ef665f --- /dev/null +++ b/test/sot/test_14_operators.py @@ -0,0 +1,387 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def unary_positive(x: int): + y = +x + return y + + +def unary_negative(x: paddle.Tensor): + y = -x + return y + + +def unary_not(x: paddle.Tensor): + y = not x + return y + + +def unary_invert(x: paddle.Tensor): + y = ~x + return y + + +def binary_power(x: paddle.Tensor, y: paddle.Tensor): + z = x**y + return z + + +def binary_multiply(x: paddle.Tensor, y: paddle.Tensor): + z = x * y + return z + + +def binary_matrix_multiply(x: paddle.Tensor, y: paddle.Tensor): + z = x @ y + return z + + +def binary_floor_divide(x: paddle.Tensor, y: paddle.Tensor): + z = x // y + return z + + +def binary_true_divide(x: paddle.Tensor, y: paddle.Tensor): + z = x / y + return z + + +def binary_modulo(x: paddle.Tensor, y: paddle.Tensor): + z = x % y + return z + + +def binary_add(x: paddle.Tensor, y: paddle.Tensor): + z = x + y + return z + + +def binary_subtract(x: paddle.Tensor, y: paddle.Tensor): + z = x - y + return z + + +def binary_lshift(x: int, y: int): + z = x << y + return z + + +def binary_rshift(x: int, y: int): + z = x >> y + return z + + +def binary_and(x: paddle.Tensor, y: paddle.Tensor): + z = x & y + return z + + +def binary_or(x: paddle.Tensor, y: paddle.Tensor): + z = x | y + return z + + +def binary_xor(x: paddle.Tensor, y: paddle.Tensor): + z = x ^ y + return z + + +def inplace_power(x: paddle.Tensor, y: paddle.Tensor): + x **= y + return x + + +def inplace_multiply(x: paddle.Tensor, y: paddle.Tensor): + x *= y + return x + + +def inplace_matrix_multiply(x: paddle.Tensor, y: paddle.Tensor): + x @= y + return x + + +def inplace_floor_divide(x: paddle.Tensor, y: paddle.Tensor): + x //= y + return x + + +def inplace_true_divide(x: paddle.Tensor, y: paddle.Tensor): + x /= y + return x + + +def inplace_modulo(x: paddle.Tensor, y: paddle.Tensor): + x %= y + return x + + +def inplace_add(x: paddle.Tensor, y: paddle.Tensor): + x += y + return x + + +def inplace_subtract(x: paddle.Tensor, y: paddle.Tensor): + x -= y + return x + + +def inplace_lshift(x: paddle.Tensor, y: int): + x <<= y + return x + + +def inplace_rshift(x: paddle.Tensor, y: int): + x >>= y + return x + + +def inplace_and(x: paddle.Tensor, y: paddle.Tensor): + x &= y + return x + + +def inplace_or(x: paddle.Tensor, y: paddle.Tensor): + x |= y + return x + + +def inplace_xor(x: paddle.Tensor, y: paddle.Tensor): + x ^= y + return x + + +def list_getitem(x: int, y: paddle.Tensor): + z = [x, y] + return operator.getitem(z, 1) + 1 + + +def list_getitem_slice(x: int, y: paddle.Tensor): + z = [x, y] + return operator.getitem(z, slice(0, 2)) + + +def list_setitem_int(x: int, y: paddle.Tensor): + z = [x, y] + operator.setitem(z, 0, 3) + return z + + +def list_setitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + operator.setitem(z, 1, paddle.to_tensor(3)) + return z + + +def list_delitem_int(x: int, y: paddle.Tensor): + z = [x, y] + operator.delitem(z, 0) + return z + + +def list_delitem_tensor(x: int, y: paddle.Tensor): + z = [x, y] + operator.delitem(z, 1) + return z + + +def dict_getitem_int(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return operator.getitem(z, 1) + + +def dict_getitem_tensor(x: int, y: paddle.Tensor): + z = {1: y, 2: y + 1} + return operator.getitem(z, 2) + + +def dict_setitem_int(x: int, y: paddle.Tensor): + z = {'x': x, 'y': y} + operator.setitem(z, 'x', 2) + return z + + +def dict_setitem_tensor(x: int, y: paddle.Tensor): + z = {'x': x, 'y': y} + operator.setitem(z, 'y', paddle.to_tensor(3)) + return z + + +def dict_delitem_int(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + operator.delitem(z, 1) + return z + + +def dict_delitem_tensor(x: int, y: paddle.Tensor): + z = {1: x, 2: y + 1} + operator.delitem(z, 2) + return z + + +def tuple_getitem_int(x: int, y: paddle.Tensor): + x = (x, y) + return operator.getitem(x, 0) + + +def tuple_getitem_tensor(x: int, y: paddle.Tensor): + x = (x, y) + return operator.getitem(x, 1) + + +def tuple_getitem_slice(x: int, y: paddle.Tensor): + x = (x, y, 1) + return operator.getitem(x, slice(0, 2)) + + +def operator_add(x: int, y: paddle.Tensor): + return operator.add(x, y) + + +def operator_mul(x: int, y: paddle.Tensor): + return operator.mul(x, y) + + +def operator_truth(y: paddle.Tensor): + return operator.truth(y) + + +def operator_is_(x: paddle.Tensor, y: paddle.Tensor): + return (operator.is_(x, x), operator.is_(x, y)) + + +def operator_in_(x: int, y: list): + return x in y + + +def operator_not_in_(x: int, y: list): + return x not in y + + +def operator_is_not(x: paddle.Tensor, y: paddle.Tensor): + return (operator.is_not(x, x), operator.is_not(x, y)) + + +def operator_pos(y: int): + return operator.pos(+y) + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(True) + c = paddle.to_tensor(3) + d = paddle.to_tensor(4) + e = paddle.to_tensor([[1, 2], [3, 4], [5, 6]], dtype='float32') + f = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32') + g = paddle.to_tensor(False) + + self.assert_results(unary_positive, 1) + self.assert_results(unary_negative, a) + self.assert_results(unary_not, b) + self.assert_results(unary_invert, b) + + self.assert_results(binary_power, c, d) + self.assert_results(binary_multiply, c, d) + self.assert_results(binary_matrix_multiply, e, f) + self.assert_results(binary_floor_divide, c, d) + self.assert_results(binary_true_divide, c, d) + self.assert_results(binary_modulo, c, d) + self.assert_results(binary_add, c, d) + self.assert_results(binary_subtract, c, d) + self.assert_results(binary_lshift, 10, 2) + self.assert_results(binary_rshift, 10, 1) + self.assert_results(binary_and, b, g) + self.assert_results(binary_or, b, g) + self.assert_results(binary_xor, b, g) + + self.assert_results(inplace_power, c, d) + self.assert_results(inplace_multiply, c, d) + self.assert_results(inplace_matrix_multiply, e, f) + self.assert_results(inplace_floor_divide, c, d) + self.assert_results(inplace_true_divide, c, d) + self.assert_results(inplace_modulo, c, d) + self.assert_results(inplace_add, c, d) + self.assert_results(inplace_subtract, c, d) + self.assert_results(inplace_lshift, 10, 2) + self.assert_results(inplace_rshift, 10, 1) + self.assert_results(inplace_and, b, g) + self.assert_results(inplace_or, b, g) + self.assert_results(inplace_xor, b, g) + + def test_operator_simple(self): + self.assert_results(operator_add, 1, paddle.to_tensor(2)) + self.assert_results(operator_mul, 1, paddle.to_tensor(2)) + self.assert_results(operator_truth, paddle.to_tensor(2)) + self.assert_results( + operator_is_, paddle.to_tensor(2), paddle.to_tensor(3) + ) + self.assert_results( + operator_is_not, paddle.to_tensor(2), paddle.to_tensor(3) + ) + self.assert_results(operator_pos, 1) + self.assert_results(operator_in_, 12, [1, 2, 12]) + self.assert_results(operator_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + self.assert_results(operator_not_in_, 12, [1, 2, 3]) + + def test_operator_list(self): + self.assert_results(list_getitem, 1, paddle.to_tensor(2)) + self.assert_results(list_getitem_slice, 1, paddle.to_tensor(2)) + self.assert_results(list_setitem_int, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + list_setitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) + self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) + + def test_operator_dict(self): + self.assert_results(dict_getitem_int, 1, paddle.to_tensor(2)) + self.assert_results(dict_getitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(dict_setitem_int, 1, paddle.to_tensor(2)) + self.assert_results_with_side_effects( + dict_setitem_tensor, 1, paddle.to_tensor(2) + ) + self.assert_results(dict_delitem_int, 1, paddle.to_tensor(2)) + self.assert_results(dict_delitem_tensor, 1, paddle.to_tensor(2)) + + def test_operator_tuple(self): + self.assert_results(tuple_getitem_int, 1, paddle.to_tensor(2)) + self.assert_results(tuple_getitem_tensor, 1, paddle.to_tensor(2)) + self.assert_results(tuple_getitem_slice, 1, paddle.to_tensor(2)) + + +def run_not_eq(x: paddle.Tensor, y: int): + out = paddle.reshape(x, [1, -1]) != y + out = out.astype('float32') + return out + + +class TestNotEq(TestCaseBase): + def test_not_eq(self): + x = paddle.to_tensor([2]) + y = 3 + self.assert_results(run_not_eq, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_15_slice.py b/test/sot/test_15_slice.py new file mode 100644 index 00000000000000..b2ee00526f25b7 --- /dev/null +++ b/test/sot/test_15_slice.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# BUILD_SLICE (new) + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +def build_list_slice(x: list, y: paddle.Tensor): + x[2:4] = [0, 1] + return x[0] + y + + +def build_list_slice_with_step(x: list, y: paddle.Tensor): + x[1:5:2] = [0, 1] + return x[0] + y + + +def build_tuple_slice(x: list, y: paddle.Tensor): + x[2:4] = (0, 1) + return x[0] + y + + +def build_tuple_slice_with_step(x: list, y: paddle.Tensor): + x[1:5:2] = (0, 1) + return x[0] + y + + +def tensor_subscript_ellipsis(x: paddle.Tensor, y: paddle.Tensor): + return x[...] + y[...] + + +@check_no_breakgraph +def tensor_subscript_tensor(x: paddle.Tensor): + d0, d1 = paddle.shape(x) + return x[: d0 // 2, d1 // 2 : d1] + + +class TestSlice(TestCaseBase): + def test_simple(self): + x = list(range(10)) + y = paddle.arange(10) + self.assert_results_with_side_effects(build_list_slice, x, y) + self.assert_results_with_side_effects(build_list_slice_with_step, x, y) + self.assert_results_with_side_effects(build_tuple_slice, x, y) + self.assert_results_with_side_effects(build_tuple_slice_with_step, x, y) + + +class MyLayer(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linears = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for i in range(10)] + ) + + def forward(self, x): + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + + +def layer_list_slice(layer, x): + out = layer(x) + return out + + +class TestLayerList(TestCaseBase): + def test_layer_list_slice(self): + layer = MyLayer() + x = paddle.randn([5, 10]) + self.assert_results(layer_list_slice, layer, x) + + +def tensor_slice(x: paddle.Tensor): + return x[1, 1, 1] + 1 + + +class TestTensorSlice(TestCaseBase): + def test_tensor_slice(self): + x = paddle.randn([4, 3, 10]) + self.assert_results(tensor_slice, x) + + +class TestTensorEllipsis(TestCaseBase): + def test_tensor_subscript_ellipsis(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + self.assert_results(tensor_subscript_ellipsis, x, y) + + +class TestTensorSubscriptTensor(TestCaseBase): + def test_tensor_subscript_tensor(self): + x = paddle.rand((10, 10)) + self.assert_results(tensor_subscript_tensor, x) + + +class LayerListNet(paddle.nn.Layer): + def __init__(self) -> None: + super().__init__() + self.layer_list = paddle.nn.LayerList( + [paddle.nn.Linear(5, 5), paddle.nn.Linear(5, 5)] + ) + + def forward(self, x): + out = self.layer_list[0](x) + for layer in self.layer_list[1:]: + out = layer(out) + return out + + +class TestLayerListSlice(TestCaseBase): + def test_layer_list_slice(self): + x = paddle.randn([2, 5]) + net = LayerListNet() + self.assert_results(layer_list_slice, net, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_16_paddle_api.py b/test/sot/test_16_paddle_api.py new file mode 100644 index 00000000000000..9f6e05fa48b2fc --- /dev/null +++ b/test/sot/test_16_paddle_api.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.nn.functional import relu + + +def paddle_api_method_call(x: paddle.Tensor): + m = x + 2 + m = paddle.nn.functional.relu(m) + return m + + +def paddle_api_function_call(x: paddle.Tensor): + m = x + 2 + m = relu(m) + return m + + +def paddle_api_function_call_concat( + x: paddle.Tensor, y: paddle.Tensor, axis: int +): + return paddle.concat([x, y], axis=axis) + + +class TestPaddleApiCall(TestCaseBase): + def test_paddle_api_method_call(self): + self.assert_results(paddle_api_method_call, paddle.to_tensor(2.0)) + self.assert_results(paddle_api_method_call, paddle.to_tensor(-5.0)) + self.assert_results(paddle_api_method_call, paddle.to_tensor(0.0)) + + def test_paddle_api_function_call(self): + self.assert_results(paddle_api_function_call, paddle.to_tensor(2.0)) + self.assert_results(paddle_api_function_call, paddle.to_tensor(-5.0)) + self.assert_results(paddle_api_function_call, paddle.to_tensor(0.0)) + + def test_paddle_api_function_call_concat(self): + a = paddle.to_tensor([[1, 2], [3, 4]]) + b = paddle.to_tensor([[5, 6], [7, 8]]) + self.assert_results(paddle_api_function_call_concat, a, b, 0) + self.assert_results(paddle_api_function_call_concat, a, b, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_17_paddle_layer.py b/test/sot/test_17_paddle_layer.py new file mode 100644 index 00000000000000..58b7dfb9fa301d --- /dev/null +++ b/test/sot/test_17_paddle_layer.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear1 = paddle.nn.Linear(10, 1) + + def forward(self, x): + out1 = self.linear1(x) + return out1 + + +class SimpleNet_bound(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear1 = paddle.nn.Linear(10, 1) + + def add(self, x): + return x + 1 + + def forward(self, x): + x = self.add(x) + out1 = self.linear1(x) + return out1 + + +def net_call(x: paddle.Tensor, net): + return net(x) + + +def net_call_passed_by_user(x: paddle.Tensor, net_forward): + return net_forward(x) + + +class SimpleNetWithSequenital(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.seq = paddle.nn.Sequential( + paddle.nn.Linear(10, 10), + paddle.nn.Linear(10, 10), + paddle.nn.Linear(10, 1), + ) + + def forward(self, x): + out1 = self.seq(x) + return out1 + + +class TestLayer(TestCaseBase): + def test_layer(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNet() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + self.assert_results(net_call_passed_by_user, x, net.forward) + + def test_layer_with_sequential(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNetWithSequenital() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + self.assert_results(net_call_passed_by_user, x, net.forward) + + def test_bound(self): + x = paddle.rand((10,)) + y = paddle.rand((10, 10)) + net = SimpleNet_bound() + self.assert_results(net_call, x, net) + self.assert_results(net_call, y, net) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_18_tensor_method.py b/test/sot/test_18_tensor_method.py new file mode 100644 index 00000000000000..2591db1f748d93 --- /dev/null +++ b/test/sot/test_18_tensor_method.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def tensor_method_call_1(x: paddle.Tensor): + y = x + 1 + return y.mean() + + +def tensor_method_call_2(a: paddle.Tensor, b: paddle.Tensor): + c = a.add(b) + d = c.multiply(a) + e = d.subtract(b) + f = e.divide(a) + g = f.pow(2) + f.abs().sqrt() + h = (g.abs() + 1).log() - (g / g.max()).exp() + i = h.sin() + h.cos() + return i + + +def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor): + return func(a) + + +def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor): + return ( + a.name, + str(a.place), + a.persistable, + a.dtype, + a.type, + a.is_tensor(), + a.clear_gradient(), + a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(), + ) + + +def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor): + c = a + b + return c.name + + +class TestTensorMethod(TestCaseBase): + def test_tensor_method_1(self): + x = paddle.rand([10]) + y = paddle.rand([2, 4, 6]) + self.assert_results(tensor_method_call_1, x) + self.assert_results(tensor_method_call_1, y) + + def test_tensor_method_2(self): + x = paddle.rand([42]) + y = paddle.rand([42]) + self.assert_results(tensor_method_call_2, x, y) + + def test_tensor_method_passed_by_user(self): + x = paddle.rand([42]) + y = paddle.rand([42]) + self.assert_results(tensor_method_passed_by_user, x, y.add) + + def test_tensor_method_property(self): + x = paddle.rand([42, 24], dtype='float64') + y = paddle.rand([42, 24], dtype='float32') + self.assert_results(tensor_method_property, x, y) + + @unittest.skip("TODO: dynamic tensor name is different") + def test_middle_tensor_name(self): + x = paddle.rand([42, 24]) + y = paddle.rand([42, 24]) + self.assert_results(middle_tensor_name, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_19_closure.py b/test/sot/test_19_closure.py new file mode 100644 index 00000000000000..6191141e07f390 --- /dev/null +++ b/test/sot/test_19_closure.py @@ -0,0 +1,260 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x: int, y: paddle.Tensor): + z = 3 + + def local(a, b=5): + return a + x + z + b + y + + return local(4) + z + + +def foo2(y: paddle.Tensor, x=1): + """ + Test strip default value + """ + z = 3 + + def local(a, b=5): + return a + x + z + b + y + + return local(4) + + +def foo3(y: paddle.Tensor, x=1): + """ + Test Closure Band Default + """ + z = 3 + + def local(a, b=5): + nonlocal z + z = 4 + return a + x + z + b + y + + return local(4) + + +global_z = 3 + + +def test_global(y: paddle.Tensor): + """ + Test Global variable + """ + + def local(a, b=5): + global global_z + global_z += 1 + return a + global_z + b + y + + return local(1) + + +def multi(c): + return c + 2 + + +def wrapper_function(func): + a = 2 + + def inner(): + return func(a) + + return inner + + +wrapped_multi = wrapper_function(multi) + + +def foo5(y: paddle.Tensor): + """ + Test incoming closures + """ + a = wrapped_multi() + return a + + +def outwrapper(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +def foo6(y: paddle.Tensor): + """ + Test Decorator + """ + + @outwrapper + def load_1(a, b=5): + return a + b + + return load_1(1) + + +import numpy as np + + +def numpy_sum(m): + """ + Test loop call + + Example: a->b->c->a + """ + a = np.array([1, 2, 3]) + tmp = np.sum(a) + return m + 1 + + +def lambda_closure(x, m): + """ + lambda closure. + """ + + def break_graph_closure(): + print("yes") + return x + m + + return break_graph_closure() + + +# motivated by python builtin decorator +def kwargs_wrapper(func): + sig = inspect.signature(func) + + def inner(*args, **kwargs): + return func(*args, **kwargs) + + inner.__signature__ = sig + return inner + + +@kwargs_wrapper +def func7(a, b): + return a + b + + +def foo7(): + return func7(3, 5) + + +def create_closure(): + x = 1 + + def closure(): + return x + 1 + + return closure + + +class TestExecutor(TestCaseBase): + def test_closure(self): + self.assert_results(foo, 1, paddle.to_tensor(2)) + self.assert_results(foo2, paddle.to_tensor(2)) + self.assert_results(foo3, paddle.to_tensor(2)) + self.assert_results_with_global_check( + test_global, ["global_z"], paddle.to_tensor(2) + ) + self.assert_results(foo5, paddle.to_tensor(2)) + self.assert_results(foo6, paddle.to_tensor(2)) + self.assert_results(numpy_sum, paddle.to_tensor(1)) + with strict_mode_guard(0): + self.assert_results( + lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1) + ) + + +class TestExecutor2(TestCaseBase): + def test_closure(self): + self.assert_results(foo7) + + +# Side Effect. +def test_slice_in_for_loop(x, iter_num=3): + x = paddle.to_tensor(x) + a = [] + # Use `paddle.full` so that static analysis can analyze the type of iter_num is Tensor + iter_num = paddle.full( + shape=[1], fill_value=iter_num, dtype="int32" + ) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved + + for i in range(iter_num): + a.append(x) + + for i in range(iter_num): + a[i] = x + out = a[2] + return out + + +class TestExecutor3(TestCaseBase): + def test_closure(self): + tx = paddle.to_tensor([1.0, 2.0, 3.0]) + # need side effect of list. + # self.assert_results(test_slice_in_for_loop, tx) + + +def non_local_test(t: paddle.Tensor): + a = 1 + + def func1(): + nonlocal a + t = a + a = 2 + return t + + def func2(): + nonlocal a + a = 1 + return a + + t += func1() # add 2 + t += func2() # add 1 + t += a # add 1 + return t + + +class TestExecutor4(TestCaseBase): + def test_closure(self): + tx = paddle.to_tensor([1.0]) + self.assert_results(non_local_test, tx) + + +class TestCreateClosure(TestCaseBase): + def test_create_closure(self): + closure = create_closure() + self.assert_results(closure) + + +if __name__ == "__main__": + unittest.main() + +# Instructions: +# LOAD_CLOSURE +# LOAD_DEREF +# LOAD_CLASSDEREF +# STORE_DEREF +# DELETE_DEREF +# STORE_GLOBAL diff --git a/test/sot/test_20_string.py b/test/sot/test_20_string.py new file mode 100644 index 00000000000000..5e628b795afdde --- /dev/null +++ b/test/sot/test_20_string.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import assert_true, check_no_breakgraph + + +def string_format(x: paddle.Tensor): + whilespace = 123 + hello_world = f"Hello {whilespace} World" + z = assert_true(hello_world == "Hello 123 World") + hello_world2 = f"Hello {whilespace}{whilespace} World" + z = assert_true(hello_world2 == "Hello 123123 World") + hello_world_lower = "Hello World".lower() + z = assert_true(hello_world_lower == "hello world") + return x + 1 + + +def string_lower(x: paddle.Tensor): + hello_world_lower = "Hello World".lower() + z = assert_true(hello_world_lower == "hello world") + return x + 1 + + +@check_no_breakgraph +def str_startswith(): + s = "Hello World" + a1 = s.startswith("Hello") + a2 = s.startswith("World") + a3 = s.startswith("Hello World") + a4 = s.startswith("Hello World!") + a5 = s.startswith("Hello", 5) + a6 = s.startswith("Hello", 1, 4) + a7 = s.startswith("Hello", 0, 11) + return (a1, a2, a3, a4, a5, a6, a7) + + +@check_no_breakgraph +def str_endswith(): + s = "Hello World" + a1 = s.endswith("Hello") + a2 = s.endswith("World") + a3 = s.endswith("Hello World") + a4 = s.endswith("Hello World!") + a5 = s.endswith("Hello", 5) + a6 = s.endswith("Hello", 0, 4) + a7 = s.endswith("Hello", 1, 11) + return (a1, a2, a3, a4, a5, a6, a7) + + +class TestExecutor(TestCaseBase): + def test_string_format(self): + self.assert_results(string_format, paddle.to_tensor(1)) + + def test_string_lower(self): + self.assert_results(string_lower, paddle.to_tensor(1)) + + def test_str_startswith(self): + self.assert_results(str_startswith) + + def test_str_endswith(self): + self.assert_results(str_endswith) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_21_global.py b/test/sot/test_21_global.py new file mode 100644 index 00000000000000..131f9c7e367f90 --- /dev/null +++ b/test/sot/test_21_global.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit import sot + +global_x = 1 +global_y = paddle.to_tensor(2) +global_z = None +global_del_val = 1 +global_dict = {} +global_list = [1, 2] +global_inline = 0 + + +def global_func_int(): + global global_x + global_x = global_x + 1 + return global_x + + +def global_func_int_add(): + global global_x + global_x = global_x + global_x + return global_x + global_x + + +def global_func_tensor_int_add(tensor_y: paddle.Tensor): + global global_x + global_x += 1 + return global_x + tensor_y + + +def global_multiple_update(): + global global_x + global_x = 999 + global_x = 888 + global_x = 777 + return global_x - 1 + + +def global_func_tensor(): + global global_y + global_y = global_y + global_y + return global_y + + +def global_func_tensor_add(): + global global_y + global_y = global_y + global_y + return global_y + global_y + + +def global_func(): + global global_x + global global_y + global global_z + + global_z = global_x + global_y + return global_z + + +def global_del_global(): + global global_del_val + + del global_del_val + + +def global_func_dict(): + global global_dict + global_dict["key"] = "value" + global_dict.update({"test_key1": "test_value2"}) + return global_dict + + +def global_func_control1(): + global global_dict + if "key" in global_dict: + del global_dict["key"] + return global_dict + + +def global_func_control2(): + global global_list + for i in range(len(global_list)): + global_list[i] = global_list[i] + 1 + return global_list + + +def global_func_inline_inner_1(): + global global_inline + global_func_inline_inner_2() + global_inline += 1 + + +def global_func_inline_inner_2(): + global global_inline + global_inline += 1 + + +def global_func_inline(): + global_func_inline_inner_1() + global global_inline + return global_inline + + +class TestGlobal(TestCaseBase): + def test_global_func_int(self): + global global_x + self.assert_results_with_global_check(global_func_int, ["global_x"]) + global_x += 1 + self.assert_results_with_global_check(global_func_int, ["global_x"]) + self.assert_results_with_global_check(global_func_int_add, ["global_x"]) + + def test_global_multiple_update(self): + self.assert_results_with_global_check( + global_multiple_update, ["global_x"] + ) + + def test_global_func_tensor_int_add(self): + self.assert_results_with_global_check( + global_func_tensor_int_add, ["global_x"], paddle.to_tensor(1) + ) + + def test_global_func_tensor(self): + self.assert_results_with_global_check(global_func_tensor, ["global_y"]) + self.assert_results_with_global_check( + global_func_tensor_add, ["global_y"] + ) + + def test_global_func(self): + self.assert_results_with_global_check(global_func, ["global_z"]) + self.assertIn("global_del_val", global_del_global.__globals__) + sot.symbolic_translate(global_del_global)() + self.assertNotIn("global_del_val", global_del_global.__globals__) + + def test_global_func_dict(self): + self.assert_results_with_global_check(global_func_dict, ["global_dict"]) + self.assert_results_with_global_check( + global_func_control1, ["global_dict"] + ) + + def test_global_func_list(self): + self.assert_results_with_global_check( + global_func_control2, ["global_list"] + ) + + def test_global_func_inline(self): + global global_inline + global_inline = 0 + sot.symbolic_translate(global_func_inline)() + self.assertEqual(global_inline, 2) + sot.symbolic_translate(global_func_inline)() + self.assertEqual(global_inline, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_analysis_inputs.py b/test/sot/test_analysis_inputs.py new file mode 100644 index 00000000000000..20b32c2225324f --- /dev/null +++ b/test/sot/test_analysis_inputs.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import sys +import unittest + +import paddle +from paddle.jit.sot.opcode_translator.instruction_utils import ( + analysis_inputs, + calc_offset_from_bytecode_offset, + get_instructions, +) + + +def assert_inputs_equals(instruction_offset: int, expected_inputs: set[str]): + current_frame = inspect.currentframe() + assert current_frame is not None + test_frame = current_frame.f_back + assert test_frame is not None + + instructions = get_instructions(test_frame.f_code) + current_instr_idx = calc_offset_from_bytecode_offset( + test_frame.f_lasti + 2, instructions + ) + actual_inputs = analysis_inputs( + instructions, current_instr_idx + instruction_offset + ) + assert ( + set(actual_inputs) == expected_inputs + ), f"actual_inputs: {actual_inputs}, expected_inputs: {expected_inputs}" + + +def case1(x): + m = x + 1 + n = x + 2 + assert_inputs_equals(0, {"x", "n"}) + y = x + 2 + assert_inputs_equals(0, {"n"}) + return n + + +def case2(x): + x = x + 1 + assert_inputs_equals(0, {"x"}) + y = x + 3 + z = x + y + assert_inputs_equals(0, {"x"}) + x += 1 + m = x + 1 + n = x + m + assert_inputs_equals(0, set()) + return 1 + + +def case3(x): + y = x + 1 + + assert_inputs_equals(0, {"x"}) + if x: + z = 1 + else: + z = 2 + return z + + +def case4(x): + y = x + 1 + + assert_inputs_equals(0, {"x", "y"}) + if x: + z = y + else: + z = x + return z + + +def case5(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"z"}) + if z: + a = 1 + else: + b = 2 + return z + + +def case6(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "z"}) + if z: + a = 1 + else: + a += 1 + return z + + +def case7(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "z"}) + if not z: + a += 1 # noqa: F821 + else: + a = 1 + return z + + +def breakgraph_api(x): + return x + + +def normal_api(x): + return x + + +def case8(x): + x = normal_api(x) + assert_inputs_equals(0, {"x"}) + for i in range(10): + x += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +case9_offset = -9 if sys.version_info >= (3, 11) else -7 + + +def case9(x): + x = breakgraph_api(x) + assert_inputs_equals( + case9_offset, set() + ) # analysis when call breakgraph api (CALL_FUNCTION) + for i in range(10): + x += 1 + if i > 5: + continue + x += 10086 + x += i + return x + + +def case10(x): + assert_inputs_equals(0, {"x", "y"}) + # if x == 0, y will be read before assignment + for i in range(x): + y = i + z = y + + return y + 1 + + +def case11(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "y", "z"}) + if z: + if not y: + a += 1 # noqa: F821 + else: + a = 2 + else: + if y: + a = 1 + else: + a += 1 + return z + + +def case12(x): + y = x + 1 + z = x + 2 + + assert_inputs_equals(0, {"a", "y", "z"}) + if z: + if y: + a = 2 + else: + a += 2 + else: + if y: + a += 1 + else: + a = 1 + return z + + +class TestAnalysisInputs(unittest.TestCase): + def test_case1(self): + case1(paddle.to_tensor([1])) + + def test_case2(self): + case2(paddle.to_tensor([2])) + + def test_case3(self): + case3(paddle.to_tensor([3])) + + def test_case4(self): + case4(paddle.to_tensor([4])) + + def test_case5(self): + case5(paddle.to_tensor([5])) + + def test_case6(self): + case6(paddle.to_tensor([6])) + + def test_case7(self): + case7(paddle.to_tensor([7])) + + def test_case8(self): + case8(paddle.to_tensor([8])) + + def test_case9(self): + case9(paddle.to_tensor([9])) + + def test_case10(self): + case10(paddle.to_tensor([10])) + + def test_case11(self): + case11(paddle.to_tensor([11])) + + def test_case12(self): + case12(paddle.to_tensor([12])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_break_graph.py b/test/sot/test_break_graph.py new file mode 100644 index 00000000000000..cc1aca51caec30 --- /dev/null +++ b/test/sot/test_break_graph.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.utils.paddle_api_config import add_break_graph_apis + + +def ifelse_func(x, y): + if x > 0: + y = y + 1 + else: + y = y + 2 + return y + + +class TestIfElse(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(ifelse_func, x, y) + + +def multi_output(x: paddle.Tensor): + m = x + 1 + if x > 0: + return m + else: + return 2 * m + + +class TestExecutor(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + self.assert_results(multi_output, x) + x = paddle.to_tensor(-2) + self.assert_results(multi_output, x) + + +def print_break_graph(x, y): + z = x + y + print(x, z) + out = y * z * 2 + return out + + +class TestPrint(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + y = paddle.to_tensor(3) + self.assert_results(print_break_graph, x, y) + + +def to_tensor_break_graph(x, y): + z = x + y + out = y * paddle.to_tensor(2) * z + return out + + +class TestToTensor(TestCaseBase): + def test_simple(self): + add_break_graph_apis([paddle.to_tensor]) + x = paddle.to_tensor(2) + y = paddle.to_tensor(3) + self.assert_results(to_tensor_break_graph, x, y) + + +def tensor_clear_gradient(x): + x = paddle.to_tensor(x) + x.clear_gradient() + return x + + +class TestBreakGraphInResumeFn(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor(2) + self.assert_results(tensor_clear_gradient, x) + + +def inner_fn(a, b, c, d): + return a + b * c - d + + +def multi_stack_args(a, b, c): + out = inner_fn(a, b, c, paddle.to_tensor(4)) + return out + + +class TestMultiStackArgs(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor(3) + self.assert_results(multi_stack_args, a, b, c) + + +def break_graph_in_call_method(x): + out = paddle.nn.functional.relu(paddle.to_tensor([4.0])) + return x + out + + +def numpy_break_graph(): + a = paddle.to_tensor([1, 2]) + b = np.sum(a.numpy()) + print(b) + return b + + +class TestBreakGraphInCallMethod(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + break_graph_in_call_method(x) + x = paddle.to_tensor([2.0]) + break_graph_in_call_method(x) + + x = paddle.to_tensor([3.0]) + self.assert_results(break_graph_in_call_method, x) + + def test_numpy(self): + self.assert_results(numpy_break_graph) + + +def test_break_graph_repeat(x): + out = paddle.to_tensor( + paddle.to_tensor(paddle.to_tensor(paddle.to_tensor([1.0]))) + ) + return x + out + + +class TestBreakGraphRepeat(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + test_break_graph_repeat(x) + x = paddle.to_tensor([2.0]) + test_break_graph_repeat(x) + + x = paddle.to_tensor([3.0]) + self.assert_results(test_break_graph_repeat, x) + + +def break_graph_resume_pass_null(x, y): + return paddle.add(x, y[0:50] if y is not None else None) + + +class TestBreakGraphResumePassNull(TestCaseBase): + def test_break_graph_resume_pass_null(self): + x = paddle.rand([50, 50], dtype=paddle.float32) + y = paddle.rand([100, 50], dtype=paddle.float32) + self.assert_results(break_graph_resume_pass_null, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_builtin_dispatch.py b/test/sot/test_builtin_dispatch.py new file mode 100644 index 00000000000000..e4a1ee5fb29993 --- /dev/null +++ b/test/sot/test_builtin_dispatch.py @@ -0,0 +1,329 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import operator +import unittest +import weakref + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph + + +def dispatch_len(x: paddle.Tensor): + return len(x.shape) + + +def dispatch_tensor_len(x: paddle.Tensor): + return len(x) + + +def dispatch_reversed(x: paddle.Tensor | int, y: paddle.Tensor | int): + return list(reversed([x + 1, y - 1, x * 10, y + 1000])) + + +def dispatch_bool(x: paddle.Tensor): + return operator.truth(x.shape) and bool(x.shape) + + +def dispatch_ceil(x: paddle.Tensor | float): + return math.ceil(x) + 1 + + +def dispatch_floor(x: paddle.Tensor | float): + return math.floor(x) + 1 + + +def test_sum_tuple(x: paddle.Tensor | int, y: paddle.Tensor | int): + return sum((x, y)) + + +def test_sum_tuple2( + x: paddle.Tensor | int | list[int] | list[paddle.Tensor], + y: paddle.Tensor | int | list[int] | list[paddle.Tensor], +): + return sum((x, y), x) + + +def test_sum_tuple3(x): + return sum((), x) + + +def test_sum_list(x: paddle.Tensor | int, y: paddle.Tensor | int): + return sum([x, y]) + + +def test_sum_list2( + x: paddle.Tensor | int | list[int] | list[paddle.Tensor], + y: paddle.Tensor | int | list[int] | list[paddle.Tensor], +): + return sum([x, y], x) + + +def test_sum_list3(x): + return sum([], x) + + +def test_tensor_sum(x: paddle.Tensor): + return sum(x) + + +def test_tensor_sum_api(x: paddle.Tensor): + return x.sum() + + +def test_pow(x: paddle.Tensor | int, y: paddle.Tensor | int): + return pow(x, y) + + +def test_pow2(x: paddle.Tensor | int, y: paddle.Tensor | int): + return pow(x, y, 1) + + +def test_tensor_pow_api(x: paddle.Tensor, y: paddle.Tensor | int): + return x.pow(y) + + +def test_math_pow(x: int, y: int): + return math.pow(x, y) + + +def test_chr(x: int | hex | paddle.Tensor): + return chr(x) + + +def test_ord(x: str): + return ord(x) + + +@check_no_breakgraph +def test_sqrt(x: int): + return math.sqrt(x) + + +class TestBuiltinDispatch(TestCaseBase): + def test_dispatch_len(self): + self.assert_results(dispatch_len, paddle.to_tensor([1, 2, 3])) + + def test_dispatch_bool(self): + self.assert_results(dispatch_bool, paddle.to_tensor([1, 2, 3])) + + def test_dispatch_tensor_len(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + dispatch_tensor_len, paddle.to_tensor([1, 2, 3]) + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + dispatch_tensor_len, paddle.to_tensor([4, 5, 6]) + ) + self.assertEqual(ctx.translate_count, 1) + + def test_dispatch_list_reversed(self): + self.assert_results(dispatch_reversed, paddle.to_tensor(1), 2) + self.assert_results(dispatch_reversed, 2, paddle.to_tensor(1)) + + def test_dispatch_tensor_reversed(self): + self.assert_results( + dispatch_reversed, + paddle.to_tensor([1, 2]), + paddle.to_tensor([3, 4]), + ) + + def test_not_dispatch_tensor_ceil(self): + # ceil should break graph, since it returns a int rather than a tensor + self.assert_results(dispatch_ceil, paddle.to_tensor(1.2)) + + def test_dispatch_float_ceil(self): + self.assert_results(dispatch_ceil, 1.2) + + def test_not_dispatch_tensor_floor(self): + # floor should break graph, since it returns a int rather than a tensor + self.assert_results(dispatch_floor, paddle.to_tensor(1.2)) + + def test_dispatch_float_floor(self): + self.assert_results(dispatch_floor, 1.2) + + def test_dispatch_sum(self): + self.assert_results(test_sum_tuple, 1, 1) + self.assert_results(test_sum_tuple, paddle.to_tensor(1), 1) + self.assert_results( + test_sum_tuple, paddle.to_tensor(1), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_tuple, paddle.to_tensor([1, 2]), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_tuple, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3]) + ) + self.assert_results(test_sum_tuple2, 1, 1) + self.assert_results(test_sum_tuple2, [1, 2], [3, 4]) + self.assert_results(test_sum_tuple2, paddle.to_tensor(1), 1) + self.assert_results( + test_sum_tuple2, paddle.to_tensor(1), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_tuple2, + [paddle.to_tensor(1), paddle.to_tensor(2)], + [paddle.to_tensor(3), paddle.to_tensor(4)], + ) + self.assert_results( + test_sum_tuple2, paddle.to_tensor([1, 2]), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_tuple2, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3]) + ) + self.assert_results(test_sum_tuple3, 1) + self.assert_results(test_sum_tuple3, paddle.to_tensor(1)) + self.assert_results(test_sum_list, 1, 1) + self.assert_results(test_sum_list, paddle.to_tensor(1), 1) + self.assert_results( + test_sum_list, paddle.to_tensor(1), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_list, paddle.to_tensor([1, 2]), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_list, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3]) + ) + self.assert_results(test_sum_list2, 1, 1) + self.assert_results(test_sum_list2, [1, 2], [3, 4]) + self.assert_results(test_sum_list2, paddle.to_tensor(1), 1) + self.assert_results( + test_sum_list2, paddle.to_tensor(1), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_list2, + [paddle.to_tensor(1), paddle.to_tensor(2)], + [paddle.to_tensor(3), paddle.to_tensor(4)], + ) + self.assert_results( + test_sum_list2, paddle.to_tensor([1, 2]), paddle.to_tensor(1) + ) + self.assert_results( + test_sum_list2, paddle.to_tensor([1, 2]), paddle.to_tensor([1, 3]) + ) + self.assert_results(test_sum_list3, 1) + self.assert_results(test_sum_list3, paddle.to_tensor(1)) + self.assert_results(test_tensor_sum, paddle.to_tensor([1, 2])) + self.assert_results(test_tensor_sum, paddle.to_tensor((1, 2))) + self.assert_results(test_tensor_sum_api, paddle.to_tensor([1, 2])) + self.assert_results(test_tensor_sum_api, paddle.to_tensor((1, 2))) + + def test_dispatch_pow(self): + self.assert_results(test_pow, 2, 3) + self.assert_results(test_pow, paddle.to_tensor(2), 3) + self.assert_results(test_pow, paddle.to_tensor(2), paddle.to_tensor(3)) + self.assert_results(test_pow2, 2, 3) + self.assert_results(test_math_pow, 2, 3) + self.assert_results(test_tensor_pow_api, paddle.to_tensor(2), 3) + self.assert_results( + test_tensor_pow_api, paddle.to_tensor(2), paddle.to_tensor(3) + ) + + def test_dispatch_chr(self): + self.assert_results(test_chr, 65) + self.assert_results(test_chr, 0x41) + self.assert_results(test_chr, paddle.to_tensor(65)) + self.assert_results(test_chr, paddle.to_tensor(0x41)) + + def test_dispatch_ord(self): + self.assert_results(test_ord, "a") + + def test_dispatch_sqrt(self): + self.assert_results(test_sqrt, 9) + + +def run_getattr(x: paddle.Tensor): + attr = 'dtype' + out = getattr(x, attr) + return out + + +class TestGetattr(TestCaseBase): + def test_getattr(self): + x = paddle.to_tensor(4) + self.assert_results(run_getattr, x) + + +def tensor_hasattr(x: paddle.Tensor): + return ( + hasattr(x, "dtype"), + hasattr(x, "stop_gradient"), + hasattr(x, "abs"), + hasattr(x, "non_tensor_attr"), + ) + + +class ObjectHasattr: + def __init__(self): + attr1 = 1 + attr2 = "2" + attr3 = [3] + + +def object_hasattr(x: ObjectHasattr): + return ( + hasattr(x, "attr1"), + hasattr(x, "attr2"), + hasattr(x, "attr3"), + hasattr(x, "non_obj_attr"), + ) + + +def layer_hasattr(layer: paddle.nn.Layer): + return ( + hasattr(layer, "parameters"), + hasattr(layer, "sublayers"), + hasattr(layer, "non_layer_attr"), + ) + + +class TestHasattr(TestCaseBase): + def test_tensor_hasattr(self): + x = paddle.to_tensor(4) + self.assert_results(tensor_hasattr, x) + + def test_object_hasattr(self): + x = ObjectHasattr() + self.assert_results(object_hasattr, x) + + def test_layer_hasattr(self): + x = paddle.nn.Layer() + self.assert_results(layer_hasattr, x) + + +class WeakrefableObject: + ... + + +def weakref_breakgraph(obj): + return weakref.ref(obj) + + +class TestWeakref(TestCaseBase): + def test_weakref_breakgraph(self): + obj = WeakrefableObject() + self.assert_results(weakref_breakgraph, obj) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_call_object.py b/test/sot/test_call_object.py new file mode 100644 index 00000000000000..486f3591f43269 --- /dev/null +++ b/test/sot/test_call_object.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + +patched = lambda self, x: x * self.a + +patched2 = lambda self, x: x * self.a + 3 + + +class A: + def __init__(self, a): + self.a = a + + def __call__(self, x): + return self.add(x) + + def add(self, x): + return x + self.a + + multi = patched + + +class B: + def __init__(self, a): + self.a = A(a) + + def __call__(self, x, func): + return getattr(self.a, func)(x) + + def self_call(self, x, func): + return getattr(self.a, func)(self.a, x) + + +def foo_1(a, x): + return a(x) + + +def foo_2(a, x): + return a.multi(x) + + +def foo_3(b, x): + return b(x, "multi") + + +def foo_4(b, x): + return b(x, "add") + + +def foo_5(b, x): + return b.self_call(x, "multi") + + +class TestExecutor(TestCaseBase): + def test_simple(self): + c = B(13) + c.a.multi = patched2 + self.assert_results(foo_1, A(13), paddle.to_tensor(2)) + self.assert_results(foo_2, A(13), paddle.to_tensor(2)) + self.assert_results(foo_3, B(13), paddle.to_tensor(2)) + self.assert_results(foo_4, B(13), paddle.to_tensor(2)) + self.assert_results(foo_5, c, paddle.to_tensor(2)) + self.assert_results(foo_4, c, paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_case_base.py b/test/sot/test_case_base.py new file mode 100644 index 00000000000000..03ce3c98227e8a --- /dev/null +++ b/test/sot/test_case_base.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import copy +import inspect +import os +import types +import unittest + +import numpy as np + +import paddle +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +@contextlib.contextmanager +def test_instruction_translator_cache_context(): + cache = OpcodeExecutorCache() + cache.clear() + yield cache + cache.clear() + + +def github_action_error_msg(msg: str): + if 'GITHUB_ACTIONS' in os.environ: + frame = inspect.currentframe() + if frame is not None: + # find the first frame that is in the test folder + while frame.f_back is not None: + filename = frame.f_code.co_filename + if filename.startswith("./"): + filename = f"tests/{filename[2:]}" + lineno = frame.f_lineno + output = f"\n::error file={filename},line={lineno}::{msg}" + return output + frame = frame.f_back + return None + + +class TestCaseBase(unittest.TestCase): + def assertIs(self, x, y, msg=None): + super().assertIs(x, y, msg=msg) + if msg is None: + msg = f"Assert Is, x is {x}, y is {y}" + msg = github_action_error_msg(msg) + if msg is not None: + print(msg) + + def assertEqual(self, x, y, msg=None): + super().assertEqual(x, y, msg=msg) + if msg is None: + msg = f"Assert Equal, x is {x}, y is {y}" + msg = github_action_error_msg(msg) + if msg is not None: + print(msg) + + def assert_nest_match(self, x, y): + cls_x = type(x) + cls_y = type(y) + msg = f"type mismatch, x is {cls_x}, y is {cls_y}" + self.assertIs(cls_x, cls_y, msg=msg) + + container_types = (tuple, list, dict, set) + if cls_x in container_types: + msg = f"length mismatch, x is {len(x)}, y is {len(y)}" + self.assertEqual( + len(x), + len(y), + msg=msg, + ) + if cls_x in (tuple, list): + for x_item, y_item in zip(x, y): + self.assert_nest_match(x_item, y_item) + elif cls_x is dict: + for x_key, y_key in zip(x.keys(), y.keys()): + self.assert_nest_match(x_key, y_key) + self.assert_nest_match(x[x_key], y[y_key]) + elif cls_x is set: + # TODO: Nested set is not supported yet + self.assertEqual(x, y) + elif cls_x in (np.ndarray, paddle.Tensor): + # TODO: support assert_allclose github error log + np.testing.assert_allclose(x, y) + else: + self.assertEqual(x, y) + + def assert_results(self, func, *inputs): + sym_output = symbolic_translate(func)(*inputs) + paddle_output = func(*inputs) + self.assert_nest_match(sym_output, paddle_output) + + def assert_results_with_side_effects(self, func, *inputs): + sym_inputs = copy.deepcopy(inputs) + sym_output = symbolic_translate(func)(*sym_inputs) + paddle_inputs = copy.deepcopy(inputs) + paddle_output = func(*paddle_inputs) + self.assert_nest_match(sym_inputs, paddle_inputs) + self.assert_nest_match(sym_output, paddle_output) + + def assert_results_with_global_check( + self, func, global_keys: list[str], *inputs + ): + def copy_fn(fn): + return types.FunctionType( + code=fn.__code__, + globals=copy.copy(fn.__globals__), + name=fn.__name__, + argdefs=fn.__defaults__, + closure=fn.__closure__, + ) + + sym_copied_fn = copy_fn(func) + sym_fn = symbolic_translate(sym_copied_fn) + paddle_fn = copy_fn(func) + sym_output = sym_fn(*inputs) + paddle_output = paddle_fn(*inputs) + for key in global_keys: + self.assert_nest_match( + sym_copied_fn.__globals__[key], paddle_fn.__globals__[key] + ) + self.assert_nest_match(sym_output, paddle_output) + + +@contextlib.contextmanager +def strict_mode_guard(value): + if "STRICT_MODE" not in os.environ: + os.environ["STRICT_MODE"] = "0" + old_value = os.environ["STRICT_MODE"] + os.environ["STRICT_MODE"] = str(value) + yield + os.environ["STRICT_MODE"] = old_value + + +@contextlib.contextmanager +def cost_model_guard(value): + if "COST_MODEL" not in os.environ: + os.environ["COST_MODEL"] = "True" + old_value = os.environ["COST_MODEL"] + os.environ["COST_MODEL"] = str(value) + yield + os.environ["COST_MODEL"] = old_value diff --git a/test/sot/test_code_status.py b/test/sot/test_code_status.py new file mode 100644 index 00000000000000..9fec5712c2293a --- /dev/null +++ b/test/sot/test_code_status.py @@ -0,0 +1,154 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot.opcode_translator.skip_files import skip_function +from paddle.jit.sot.utils.code_status import CodeState, CodeStatus + + +class SimpleNet1(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + for i in range(len(self.layers)): + sot.psdb.breakgraph() + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +class SimpleNet2(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.layers = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(30)] + ) + + def forward(self, x): + sot.psdb.fallback() + for i in range(len(self.layers)): + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + x = self.layers[i](x) + return x + + +def run_net(net, x): + for i in range(20): + x = net(x) + return x + + +class TestCodeInfo(TestCaseBase): + def test_case_1(self): + CodeStatus().clear() + net = SimpleNet1() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # run_net, forward, loop body, resumed part2 in loop body + assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4 + # resumed part1 in loop body + assert ( + len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1 + ) + + def test_case_2(self): + with strict_mode_guard(0): + CodeStatus().clear() + net = SimpleNet2() + inp = paddle.rand((10, 10)) + self.assert_results(run_net, net, inp) + code_map = CodeStatus().code_map + states = [] + for k, v in code_map.items(): + if k.co_name.startswith("#") or k.co_name.startswith("$"): + states.append(v) + elif k in CodeStatus().WITH_GRAPH_API: + assert v.state == CodeState.WITH_GRAPH + else: + assert v.state == CodeState.WITHOUT_GRAPH + # no graph found because fallback (paddle api will not enter simulate) + assert ( + len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0 + ) + + +def no_skip_func_0(x): + return x + 1 + + +def skipped_func_0(): + pass + + +def skipped_func_1(x): + return x + 1 + + +def skipped_func_2(x): + return no_skip_func_0(x) + + +def call_skipped_func_0(x): + for i in range(15): + skipped_func_0() + x = skipped_func_1(x) + x = skipped_func_2(x) + return x + + +skip_function(skipped_func_0) +skip_function(skipped_func_1) +skip_function(skipped_func_2) +skip_function(call_skipped_func_0) + + +class TestDisableSkippedFrame(TestCaseBase): + def test_case_0(self): + CodeStatus().clear() + x = paddle.to_tensor([1]) + self.assert_results(call_skipped_func_0, x) + code_map = CodeStatus().code_map + assert ( + code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert ( + code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH + ) + assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_constant_graph.py b/test/sot/test_constant_graph.py new file mode 100644 index 00000000000000..970f9f49024131 --- /dev/null +++ b/test/sot/test_constant_graph.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# New Supported Instructions: +# BUILD_MAP (new) +# BUILD_CONST_KEY_MAP (new) + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def func_1(format_str, tensor): + str = format_str.format(xx=12) + a = "{xx} = 12".format + ttt = f"{10} = 12" + a(xx=12) + tensor = tensor + 1 + return str, tensor + + +def func_2(format_str, tensor): + str = format_str % 10 + tensor = tensor + 1 + return str, tensor + + +class TestConstantGraph(TestCaseBase): + def test_case_1(self): + x = "{xx} is xx" + tensor = paddle.to_tensor(1) + self.assert_results(func_1, x, tensor) + + def test_case_2(self): + x = "%s is xx" + tensor = paddle.to_tensor(1) + self.assert_results(func_2, x, tensor) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_cost_model.py b/test/sot/test_cost_model.py new file mode 100644 index 00000000000000..07899a03efbfd6 --- /dev/null +++ b/test/sot/test_cost_model.py @@ -0,0 +1,114 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +from test_case_base import TestCaseBase, cost_model_guard + +import paddle +from paddle.jit.sot import psdb, symbolic_translate +from paddle.jit.sot.utils import StepInfoManager, StepState + + +def dyn_fast(x, net, iter_): + for i in iter_: + x = net(x) + return x + + +def sot_fast_with_single_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + return x + 1 + + +def sot_fast_with_multi_graph(x, net): + if not psdb.in_sot(): + time.sleep(0.1) + x = x + 1 + psdb.breakgraph() + x = x + 2 + return x + + +class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.linear = paddle.nn.Linear(10, 10) + + def forward(self, x): + if not psdb.in_sot(): + time.sleep(0.1) + x = x / 3 + x = x + 5 + x = self.linear(x) + return x + + +class TestCostModel(TestCaseBase): + @cost_model_guard("True") + def test_dyn_fast(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(dyn_fast) + for i in range(60): + sot_fn(x, net, iter(range(10))) + + state = StepInfoManager().step_record[dyn_fast.__code__].state + assert state == StepState.RUN_DYN + + @cost_model_guard("True") + def test_sot_fast_with_multi_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + sot_fn = symbolic_translate(sot_fast_with_multi_graph) + for i in range(30): + sot_fn(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_multi_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard("True") + def test_sot_fast_with_single_graph(self): + x = paddle.rand([10]) + net = paddle.nn.Linear(10, 10) + for i in range(30): + symbolic_translate(sot_fast_with_single_graph)(x, net) + + state = ( + StepInfoManager() + .step_record[sot_fast_with_single_graph.__code__] + .state + ) + assert state == StepState.RUN_SOT + + @cost_model_guard("True") + def test_net(self): + x = paddle.rand([10]) + net = Net() + net = paddle.jit.to_static(net, enable_fallback=True) + for i in range(30): + x = net(x) + + state = StepInfoManager().step_record[Net.forward.__code__].state + assert state == StepState.RUN_SOT + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_delete_fast.py b/test/sot/test_delete_fast.py new file mode 100644 index 00000000000000..9dca7d4ea1b14c --- /dev/null +++ b/test/sot/test_delete_fast.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_delete_fast(a): + a = a + 2 + t = a * 3 + del t + return a + + +class TestExecutor(TestCaseBase): + def test_simple(self): + a = paddle.to_tensor(1) + self.assert_results(test_delete_fast, a) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_dup_top.py b/test/sot/test_dup_top.py new file mode 100644 index 00000000000000..5cb28a2dc6ceac --- /dev/null +++ b/test/sot/test_dup_top.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def func_dup_top_1(): + return True == True != False + + +def func_dup_top_2(x): + y = x + 1 + return True == True != False + + +def func_dup_top_two(x: list[paddle.Tensor]): + x[0] += x[1] + return x + + +class TestDupTop(TestCaseBase): + def test_dup_top(self): + self.assert_results(func_dup_top_1) + self.assert_results(func_dup_top_2, paddle.to_tensor(1.0)) + # TODO: fix this after we support side effect + # self.assert_results( + # func_dup_top_two, [paddle.to_tensor(1.0), paddle.to_tensor(2.0)] + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_enumerate.py b/test/sot/test_enumerate.py new file mode 100644 index 00000000000000..f81a451da55c99 --- /dev/null +++ b/test/sot/test_enumerate.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def test_enumerate_1(x: int, y: int): + for id, val in enumerate(range(x)): + if id % 2 == 0: + y += val + return y + + +def test_enumerate_2(x: list): + return list(enumerate(x)) + + +def test_enumerate_3(x: list): + return tuple(enumerate(x)) + + +def test_enumerate_4(x: paddle.Tensor): + sum = 0 + for idx, val in enumerate(x): + sum += val + return sum + + +# TODO(zmh): support range for tensor +def test_enumerate_5(x: paddle.Tensor): + sum = 0 + + for idx, val in enumerate(x): + for i in range(val): + sum += val + return sum + + +def test_enumerate_6(x: paddle.Tensor): + sum = 0 + + for idx, val in enumerate(x): + for i in range(idx): + sum += val + return sum + + +def test_enumerate_7(x: paddle.Tensor): + sum = 0 + x = x.flatten() + for idx, val in enumerate(x): + sum += val + return sum + + +# TODO(zmh): support -1 +def test_enumerate_8(x: paddle.Tensor): + sum = 0 + x = paddle.nonzero(x, as_tuple=False) + for idx, val in enumerate(x): + sum += val + return sum + + +def test_enumerate_10(layer_list, x): + sum = 0 + for idx, layer in enumerate(layer_list): + sum += layer(x) + return sum + + +class TestExecutor(TestCaseBase): + def test_cases(self): + x = 8 + y = 5 + ty = paddle.randn((10, 10)) + layer_list = paddle.nn.LayerList( + [paddle.nn.Linear(10, 10) for _ in range(3)] + ) + + self.assert_results(test_enumerate_1, x, y) + self.assert_results(test_enumerate_2, [2, 4, 6, 8, 10]) + self.assert_results(test_enumerate_3, [2, 4, 6, 8, 10]) + + self.assert_results(test_enumerate_4, ty) + # TODO(zmh): support range for tensor + + with strict_mode_guard(0): + self.assert_results(test_enumerate_5, paddle.to_tensor([1, 2, 3])) + self.assert_results(test_enumerate_6, paddle.to_tensor([1, 2, 3])) + self.assert_results(test_enumerate_7, ty) + # TODO(zmh): support -1 + + with strict_mode_guard(0): + self.assert_results(test_enumerate_8, ty) + + self.assert_results(test_enumerate_10, layer_list, paddle.randn((10,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_error_handling.py b/test/sot/test_error_handling.py new file mode 100644 index 00000000000000..c74436f0d44f4f --- /dev/null +++ b/test/sot/test_error_handling.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +from paddle.jit import sot + + +def fn_with_try_except(): + sot.psdb.breakgraph() + sot.psdb.fallback() + try: + raise ValueError("ValueError") + except ValueError: + print("catch ValueError") + return True + + +class TestErrorHandling(TestCaseBase): + @strict_mode_guard(0) + def test_fn_with_try_except(self): + self.assert_results(fn_with_try_except) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_exception.py b/test/sot/test_exception.py new file mode 100644 index 00000000000000..26e0f55044379d --- /dev/null +++ b/test/sot/test_exception.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +import unittest + +import paddle +from paddle.jit.sot import symbolic_translate + + +def case1(x): + return n # noqa: F821 + + +def case2(x): + x = x + 1 + return x @ x + + +def case3(x): + y = x.undefined_attr + return y + + +def case4_inner(x): + y = x * 2 + print() + y = y + 1 + return y.undefined_attr + + +def case4(x): + return case4_inner(x) + + +def case5_inner3(x): + x += 1 + print(x) + z = x + 1 + return z + + +def case5_inner2(x): + x += 1 + z = case5_inner3(1 / 0) + return z + 1 + + +def case5_inner1(x): + return case5_inner2(x) + + +def case5(x): + y = case5_inner3(x) + return case5_inner1(y) + 1 + + +class TestException(unittest.TestCase): + def catch_error(self, func, inputs, error_lines: int | list[int]): + if isinstance(error_lines, int): + error_lines = [error_lines] + try: + symbolic_translate(func)(inputs) + except Exception as e: + match_results = re.compile(r'File ".*", line (\d+)').findall(str(e)) + match_results = list(map(int, match_results)) + assert ( + match_results == error_lines + ), f"{match_results} is not equal {error_lines}" + + def test_all_case(self): + self.catch_error(case1, paddle.rand([2, 1]), 25) + # TODO: support runtime error, such as x[111], x@x + # self.catch_error(case2, paddle.rand([2, 1]), 30) + self.catch_error(case3, paddle.rand([2, 1]), 34) + self.catch_error(case4, paddle.rand([2, 1]), 42) + self.catch_error(case5, paddle.rand([3, 1]), [68, 63, 58]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_execution_base.py b/test/sot/test_execution_base.py new file mode 100644 index 00000000000000..8c16b89ec4cf18 --- /dev/null +++ b/test/sot/test_execution_base.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot import symbolic_translate +from paddle.static import BuildStrategy + + +def func(x, y): + ret = 2 * x + ret = paddle.nn.functional.relu(ret) + ret = ret + y + return ret + + +def simple(x): + ret = 2 * x + return ret + + +class TestExecutor(TestCaseBase): + def test_simple(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(simple, x) + self.assert_results(simple, y) + + +def foo(x): + out = x + 1 + out = out * 2 + out = paddle.nn.functional.relu(out) + return out + + +class TestBackend(TestCaseBase): + def test_backend(self): + x = paddle.randn([2, 3]) + dy_out = foo(x) + sot_out = symbolic_translate( + foo, build_strategy=BuildStrategy(), backend='CINN' + )(x) + self.assert_nest_match(dy_out, sot_out) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_guard_outputs.py b/test/sot/test_guard_outputs.py new file mode 100644 index 00000000000000..c717eb8190e5fc --- /dev/null +++ b/test/sot/test_guard_outputs.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def non_operator_related_fn(x: int, y: int): + return x + y + + +def partial_non_operator_related_fn(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + y + return [a, z + z] + + +def guard_inputs(x: int, y: int, z: int): + return x + y + z + + +class TestGuardOutputs(TestCaseBase): + def test_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results(non_operator_related_fn, 1, 2) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(non_operator_related_fn, 3, 4) + self.assertEqual(ctx.translate_count, 2) + + def test_partial_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(1), + paddle.to_tensor(2), + 3, + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + partial_non_operator_related_fn, + paddle.to_tensor(4), + paddle.to_tensor(5), + 6, + ) + self.assertEqual(ctx.translate_count, 2) + + def test_guard_inputs(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results(guard_inputs, 1, 2, 3) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(guard_inputs, 0, 2, 3) + self.assertEqual(ctx.translate_count, 2) + self.assert_results(guard_inputs, 1, 0, 3) + self.assertEqual(ctx.translate_count, 3) + self.assert_results(guard_inputs, 1, 2, 0) + self.assertEqual(ctx.translate_count, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_guard_user_defined_fn.py b/test/sot/test_guard_user_defined_fn.py new file mode 100644 index 00000000000000..193164b06f58d6 --- /dev/null +++ b/test/sot/test_guard_user_defined_fn.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def test_guard_fn(fn, inp): + if fn is None: + return 0 + else: + return fn(inp) + + +class TestGuardOutputs(TestCaseBase): + def test_non_operator_related_fn(self): + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + test_guard_fn, + paddle.nn.functional.relu, + paddle.to_tensor([1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + test_guard_fn, + paddle.nn.functional.gelu, + paddle.to_tensor([1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, + paddle.nn.functional.relu, + paddle.to_tensor([-1.0, -1.0]), + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, None, paddle.to_tensor([-1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 3) + + deleted_cnt = 0 + + class Callable: + def __call__(self, var): + return paddle.nn.functional.relu(var) + + def __del__(self): + nonlocal deleted_cnt + deleted_cnt += 1 + + fn1 = Callable() + fn2 = Callable() + with test_instruction_translator_cache_context() as ctx: + self.assert_results( + test_guard_fn, fn1, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 1) + self.assert_results( + test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 2) + self.assert_results( + test_guard_fn, fn2, paddle.to_tensor([1.0, -1.0]) + ) + self.assertEqual(ctx.translate_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_inplace_api.py b/test/sot/test_inplace_api.py new file mode 100644 index 00000000000000..767368e9fe7dd4 --- /dev/null +++ b/test/sot/test_inplace_api.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot import symbolic_translate + + +def simple(x, y): + x[0] = 3.0 + z = [y] + y[1] = 5.0 + return x[0] + x[1] + z[0][1] + y[0] + y[1] + + +def inplace_in_if(x, y, z): + if z: + x[0] = 3.0 + z = [y] + y[1] = 5.0 + ret = x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + else: + return None + + +def inplace_in_if_fallback(x, y, z): + if z > 0: + x[0] = 3.0 + z = [y] + y[1] = 5.0 + ret = x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + else: + return None + + +def inplace_in_loop(x, y): + ret = 0 + for i in range(10): + x[0] = 1 + z = [y] + y[1] = 2 * i + 1 + ret += x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + + +def inplace_in_loop_fallback(x, y, it): + ret = 0 + for i in it: + x[0] = 1 + z = [y] + y[1] = 2 * i + 1 + ret += x[0] + x[1] + z[0][1] + y[0] + y[1] + return ret + + +def inplace_case_0(x): + x[:] = 1.0 + return x + + +def inplace_case_1(x): + x[0][0, 0::2] = 1.0 + return x + + +def inplace_case_2(x): + t = x[0] + t[:, 0::2] = t[:, 0::2] * 0 + t[:, 1::2] = t[:, 1::2] + 2 + return x + + +class TestExecutor(TestCaseBase): + def test_case(self): + self.assert_results(inplace_case_0, paddle.randn((1, 4))) + self.assert_results(inplace_case_1, [paddle.randn((1, 4))]) + self.assert_results(inplace_case_2, [paddle.randn((1, 4))]) + + def test_backward(self): + @symbolic_translate + def func(x): + m = x * 2 + n = x * 3 + y = m + y[:] = n + return y + + x = paddle.ones((1, 4)) * 4 + x.stop_gradient = False + y = func(x) + y.sum().backward() + assert (x.grad.numpy() == 3).all() + + def test_simple(self): + self.assert_results( + simple, paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]) + ) + + def test_if(self): + self.assert_results( + inplace_in_if, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + True, + ) + self.assert_results( + inplace_in_if_fallback, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + paddle.to_tensor(1), + ) + + def test_loop(self): + self.assert_results( + inplace_in_loop, + paddle.to_tensor([1.0, 2.0]), + paddle.to_tensor([3.0, 4.0]), + ) + + a = range(10) + sym_output = symbolic_translate(inplace_in_loop_fallback)( + paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) + ) + paddle_output = inplace_in_loop_fallback( + paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([3.0, 4.0]), iter(a) + ) + self.assert_nest_match(sym_output, paddle_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_instruction_translator_cache.py b/test/sot/test_instruction_translator_cache.py new file mode 100644 index 00000000000000..6ee1b33ebbc15f --- /dev/null +++ b/test/sot/test_instruction_translator_cache.py @@ -0,0 +1,165 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import random +import types +import unittest +from unittest.mock import patch + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +from paddle.jit.sot.opcode_translator.custom_code import CustomCode +from paddle.jit.sot.opcode_translator.executor.executor_cache import ( + OpcodeExecutorCache, +) + + +def fake_frames() -> ( + tuple[ + types.FrameType, + types.FrameType, + types.FrameType, + types.FrameType, + types.FrameType, + ] +): + def fake_inner_fn_1(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_2(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_3(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_4(): + frame = inspect.currentframe() + assert frame is not None + return frame + + def fake_inner_fn_5(): + frame = inspect.currentframe() + assert frame is not None + return frame + + return ( + fake_inner_fn_1(), + fake_inner_fn_2(), + fake_inner_fn_3(), + fake_inner_fn_4(), + fake_inner_fn_5(), + ) + + +( + FRAME_1, + FRAME_2, + FRAME_3, + FRAME_4, + FRAME_5, +) = fake_frames() + + +def mock_start_translate(frame: types.FrameType, **kwargs): + translate_map = { + FRAME_1: (CustomCode(FRAME_2.f_code, False), lambda frame: True), + FRAME_3: ( + CustomCode(FRAME_4.f_code, False), + lambda frame: False, + ), # Always re-compile + FRAME_5: (CustomCode(None, False), lambda frame: True), + } + return translate_map[frame] + + +class TestOpcodeExecutorCache(unittest.TestCase): + def reset(self): + global translate_count + translate_count = 0 + OpcodeExecutorCache().clear() + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_hit(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache hit + translated_code_2 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_miss_due_to_unknown_code(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_1) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_2.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache miss + translated_code_2 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 2) + + @patch( + "paddle.jit.sot.opcode_translator.executor.executor_cache.start_translate", + mock_start_translate, + ) + def test_cache_miss_due_to_check_failed(self): + with test_instruction_translator_cache_context() as ctx: + translated_code_1 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_1 is not None + self.assertEqual(translated_code_1.code, FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 1) + # cache miss + translated_code_2 = OpcodeExecutorCache()(FRAME_3) + assert translated_code_2 is not None + self.assertEqual(translated_code_2.code, FRAME_4.f_code) + self.assertEqual(ctx.translate_count, 2) + + +def foo(x): + return x + 1 + + +class TestCacheExceedLimit(TestCaseBase): + def test_cache_exceed_limit(self): + for _ in range(30): + input = random.random() + self.assert_results(foo, input) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/sot/test_map.py b/test/sot/test_map.py new file mode 100644 index 00000000000000..812ab36673be42 --- /dev/null +++ b/test/sot/test_map.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from typing import Iterable + +from test_case_base import TestCaseBase, strict_mode_guard + +from paddle.jit import sot +from paddle.jit.sot.psdb import check_no_breakgraph + + +def double_num(num: float | int): + return num * 2 + + +def double_num_with_breakgraph(num: float | int): + sot.psdb.breakgraph() + return num * 2 + + +@check_no_breakgraph +def test_map_list(x: list): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_list_comprehension(x: list): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_tuple(x: tuple): + return tuple(map(double_num, x)) + + +@check_no_breakgraph +def test_map_tuple_comprehension(x: tuple): + return [i for i in map(double_num, x)] # noqa: C416 + + +@check_no_breakgraph +def test_map_range(x: Iterable): + return list(map(double_num, x)) + + +@check_no_breakgraph +def test_map_range_comprehension(x: Iterable): + return [i for i in map(double_num, x)] # noqa: C416 + + +def add_dict_prefix(key: str): + return f"dict_{key}" + + +@check_no_breakgraph +def test_map_dict(x: dict): + return list(map(add_dict_prefix, x)) + + +@check_no_breakgraph +def test_map_dict_comprehension(x: dict): + return [i for i in map(add_dict_prefix, x)] # noqa: C416 + + +def test_map_list_with_breakgraph(x: list): + return list(map(double_num_with_breakgraph, x)) + + +@check_no_breakgraph +def test_map_unpack(x: list): + a, b, c, d = map(double_num, x) + return a, b, c, d + + +@check_no_breakgraph +def test_map_for_loop(x: list): + res = 0 + for i in map(double_num, x): + res += i + return res + + +class TestMap(TestCaseBase): + def test_map(self): + self.assert_results(test_map_list, [1, 2, 3, 4]) + self.assert_results(test_map_tuple, (1, 2, 3, 4)) + self.assert_results(test_map_range, range(5)) + self.assert_results(test_map_dict, {"a": 1, "b": 2, "c": 3}) + + def test_map_comprehension(self): + self.assert_results(test_map_list_comprehension, [1, 2, 3, 4]) + self.assert_results(test_map_tuple_comprehension, (1, 2, 3, 4)) + self.assert_results(test_map_range_comprehension, range(5)) + self.assert_results( + test_map_dict_comprehension, {"a": 1, "b": 2, "c": 3} + ) + + def test_map_with_breakgraph(self): + with strict_mode_guard(0): + self.assert_results(test_map_list_with_breakgraph, [1, 2, 3, 4]) + + def test_map_unpack(self): + self.assert_results(test_map_unpack, [1, 2, 3, 4]) + + def test_map_for_loop(self): + self.assert_results(test_map_for_loop, [7, 8, 9, 10]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_multiple_args.py b/test/sot/test_multiple_args.py new file mode 100644 index 00000000000000..7d5bf6b59205c7 --- /dev/null +++ b/test/sot/test_multiple_args.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def foo(x, y): + ret = x + y + return ret + + +class TestMultipleArgs(TestCaseBase): + def test_multiple_args(self): + x = paddle.to_tensor([1.0]) + y = paddle.to_tensor([2.0]) + self.assert_results(foo, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_mutable_data.py b/test/sot/test_mutable_data.py new file mode 100644 index 00000000000000..2cedee2d8529fd --- /dev/null +++ b/test/sot/test_mutable_data.py @@ -0,0 +1,354 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.jit.sot.opcode_translator.executor.mutable_data import ( + MutableData, + MutableDictLikeData, + MutableListLikeData, +) + + +class VariableBase: + def __init__(self): + ... + + +class ConstVariable(VariableBase): + def __init__(self, value): + self.value = value + + def __repr__(self): + return f"ConstVariable({self.value})" + + def __eq__(self, other): + if not isinstance(other, ConstVariable): + return False + return self.value == other.value + + +class DictVariable(VariableBase): + def __init__(self, data): + self.data = data + self.proxy = MutableDictLikeData(data, DictVariable.proxy_getter) + + @staticmethod + def proxy_getter(proxy, key): + if key not in proxy.original_data: + return MutableData.Empty() + return ConstVariable(proxy.original_data[key]) + + def getitem(self, key): + res = self.proxy.get(key) + if isinstance(res, MutableData.Empty): + raise KeyError(f"Key {key} not found") + return res + + def setitem(self, key, value): + self.proxy.set(key, value) + + def delitem(self, key): + self.proxy.delete(key) + + +class ListVariable(VariableBase): + def __init__(self, data): + self.data = data + self.proxy = MutableListLikeData(data, ListVariable.proxy_getter) + + @staticmethod + def proxy_getter(proxy, key): + if key < 0 or key >= len(proxy.original_data): + return MutableData.Empty() + return ConstVariable(proxy.original_data[key]) + + def getitem(self, key): + if isinstance(key, int): + res = self.proxy.get(key) + if isinstance(res, MutableData.Empty): + raise IndexError(f"Index {key} out of range") + return res + elif isinstance(key, slice): + return self.proxy.get_all()[key] + else: + raise TypeError(f"Invalid key type {type(key)}") + + def __getitem__(self, key): + return self.getitem(key) + + def setitem(self, key, value): + if isinstance(key, int): + self.proxy.set(key, value) + elif isinstance(key, slice): + start, end, step = key.indices(self.proxy.length) + indices = list(range(start, end, step)) + if step == 1: + # replace a continuous range + for i, idx in enumerate(indices): + self.proxy.delete(idx - i) + for i, item in enumerate(value): + self.proxy.insert(start + i, item) + else: + # replace some elements + if len(indices) != len(value): + raise ValueError( + f"Attempt to replace {len(indices)} items with {len(value)}" + ) + for i, idx in enumerate(indices): + self.proxy.set(idx, value[i]) + + def delitem(self, key): + self.proxy.delete(key) + + def insert(self, index, value): + self.proxy.insert(index, value) + + def append(self, value): + self.proxy.insert(self.proxy.length, value) + + def extend(self, value): + for item in value: + self.append(item) + + def pop(self, index=-1): + res = self.getitem(index) + self.delitem(index) + return res + + def clear(self): + for i in range(self.proxy.length): + self.delitem(0) + + def remove(self, value): + for i in range(self.proxy.length): + if self.getitem(i) == value: + self.delitem(i) + return + raise ValueError(f"Value {value} not found") + + def sort(self, key=None, reverse=False): + if key is None: + key = lambda x: x + permutation = list(range(self.proxy.length)) + permutation.sort( + key=lambda x: key(self.getitem(x).value), reverse=reverse + ) + self.proxy.permutate(permutation) + + def reverse(self): + permutation = list(range(self.proxy.length)) + permutation.reverse() + self.proxy.permutate(permutation) + + +class TestMutableDictLikeVariable(unittest.TestCase): + def test_getitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + self.assertEqual(var.getitem("a"), ConstVariable(1)) + self.assertEqual(var.getitem("b"), ConstVariable(2)) + + def test_setitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + var.setitem("a", ConstVariable(3)) + self.assertEqual(var.getitem("a"), ConstVariable(3)) + var.setitem("c", ConstVariable(4)) + self.assertEqual(var.getitem("c"), ConstVariable(4)) + + def test_delitem(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + var.delitem("a") + with self.assertRaises(KeyError): + var.getitem("a") + + def test_keys(self): + data = {"a": 1, "b": 2} + var = DictVariable(data) + self.assertEqual(list(var.proxy.get_all().keys()), ["a", "b"]) + + +class TestMutableListLikeVariable(unittest.TestCase): + def test_getitem(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(2)) + self.assertEqual(var.getitem(2), ConstVariable(3)) + + def test_getitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var.getitem(slice(0, 3)), + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(4, 1, -1)), + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var.getitem(slice(1, 5, 2)), + [ConstVariable(2), ConstVariable(4)], + ) + + def test_getitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + self.assertEqual( + var[0:3], + [ConstVariable(1), ConstVariable(2), ConstVariable(3)], + ) + self.assertEqual( + var[4:1:-1], + [ConstVariable(5), ConstVariable(4), ConstVariable(3)], + ) + self.assertEqual( + var[1:5:2], + [ConstVariable(2), ConstVariable(4)], + ) + + def test_setitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.setitem(0, ConstVariable(4)) + self.assertEqual(var.getitem(0), ConstVariable(4)) + var.append(ConstVariable(5)) + self.assertEqual(var.getitem(3), ConstVariable(5)) + + def test_setitem_slice_1(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(0, 3), [ConstVariable(4), ConstVariable(5)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 4, 5, 6, 7]], + ) + var.setitem( + slice(4, 1, -1), + [ConstVariable(8), ConstVariable(9), ConstVariable(10)], + ) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 5, 10, 9, 8, 7]], + ) + + def test_setitem_slice_2(self): + data = [1, 2, 3, 4, 5, 6, 7] + var = ListVariable(data) + var.setitem(slice(2, 5, 2), [ConstVariable(8), ConstVariable(9)]) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [1, 2, 8, 4, 9, 6, 7]], + ) + + def test_delitem(self): + data = [1, 2, 3] + var = ListVariable(data) + var.delitem(0) + with self.assertRaises(IndexError): + var.getitem(2) + var.pop() + with self.assertRaises(IndexError): + var.getitem(1) + + def test_insert(self): + data = [1, 2, 3] + var = ListVariable(data) + var.insert(0, ConstVariable(4)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 2, 3]], + ) + var.insert(2, ConstVariable(5)) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [4, 1, 5, 2, 3]], + ) + + def test_append(self): + data = [1, 2, 3] + var = ListVariable(data) + var.append(ConstVariable(4)) + self.assertEqual(var.getitem(3), ConstVariable(4)) + + def test_extend(self): + data = [1, 2, 3] + var = ListVariable(data) + var.extend([ConstVariable(4), ConstVariable(5)]) + self.assertEqual(var.getitem(3), ConstVariable(4)) + self.assertEqual(var.getitem(4), ConstVariable(5)) + + def test_pop(self): + data = [1, 2, 3] + var = ListVariable(data) + self.assertEqual(var.pop(), ConstVariable(3)) + self.assertEqual(var.pop(0), ConstVariable(1)) + + def test_clear(self): + data = [1, 2, 3] + var = ListVariable(data) + var.clear() + self.assertEqual(var.proxy.length, 0) + + def test_remove(self): + data = [1, 2, 3] + var = ListVariable(data) + var.remove(ConstVariable(2)) + self.assertEqual(var.getitem(0), ConstVariable(1)) + self.assertEqual(var.getitem(1), ConstVariable(3)) + with self.assertRaises(ValueError): + var.remove(ConstVariable(2)) + + def test_sort(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, 1, 2, 3, 4, 5]], + ) + + def test_sort_with_key(self): + data = [-1, -4, 2, 0, 5, -3] + var = ListVariable(data) + var.sort(key=lambda x: x**2) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [0, -1, 2, -3, -4, 5]], + ) + + def test_sort_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.sort(reverse=True) + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 4, 3, 2, 1, 0]], + ) + + def test_reverse(self): + data = [2, 3, 0, 4, 1, 5] + var = ListVariable(data) + var.reverse() + self.assertEqual( + [var.getitem(i) for i in range(var.proxy.length)], + [ConstVariable(n) for n in [5, 1, 4, 0, 3, 2]], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_numpy.py b/test/sot/test_numpy.py new file mode 100644 index 00000000000000..3600d4df7cc455 --- /dev/null +++ b/test/sot/test_numpy.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle + + +def foo(x, y): + ret = x + y + return ret + + +class TestNumpy(TestCaseBase): + def test_tensor_add_numpy_number(self): + x = paddle.to_tensor([1.0]) + y = np.int64(2) + self.assert_results(foo, x, y) + self.assert_results(foo, y, x) + + @strict_mode_guard(0) + def test_tensor_add_numpy_array(self): + x = paddle.to_tensor([1.0]) + y = np.array(2.0) + self.assert_results(foo, x, y) + self.assert_results(foo, y, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_numpy_var_if.py b/test/sot/test_numpy_var_if.py new file mode 100644 index 00000000000000..9d7c4a7048e251 --- /dev/null +++ b/test/sot/test_numpy_var_if.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.psdb import check_no_breakgraph, check_no_fallback + +os.environ['MIN_GRAPH_SIZE'] = '-1' + + +@check_no_breakgraph +@check_no_fallback +def forward(x, y): + if x == 0: + return y + 2 + else: + return y * 2 + + +@check_no_breakgraph +@check_no_fallback +def forward2(x, y): + if x == x: # numpy == numpy + return y + 2 + else: + return y * 2 + + +class TestJumpWithNumpy(TestCaseBase): + def test_jump(self): + self.assert_results(forward, np.array([1]), paddle.to_tensor(2)) + self.assert_results(forward, np.array([0]), paddle.to_tensor(2)) + self.assert_results(forward2, np.array([0]), paddle.to_tensor(2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_output_restoration.py b/test/sot/test_output_restoration.py new file mode 100644 index 00000000000000..9c2cf268e9087b --- /dev/null +++ b/test/sot/test_output_restoration.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def output_identity(x): + return x + + +def output_const(): + return 42 + + +def output_list(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = [1, a, b, y] + return l + + +def output_dict(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = {1: a, b: y} + return l + + +def output_dict_const_key(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + 1 + b = z + 1 + l = {1: a, 2: y} + return l + + +def output_nest_struct(x: paddle.Tensor, y: paddle.Tensor, z: int): + a = x + y + z + b = z + 1 + l = [1 + 1, (z, a), [b]] + return l + + +class TestOutputRestoration(TestCaseBase): + def test_output_identity(self): + self.assert_results(output_identity, 1) + self.assert_results(output_identity, 2) + self.assert_results(output_identity, paddle.to_tensor(1)) + + def test_output_const(self): + self.assert_results(output_const) + + def test_output_list(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_list, a, b, 3) + + def test_output_dict(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_dict, a, b, 3) + + def test_output_dict_const_key(self): + a = paddle.to_tensor(2) + b = paddle.to_tensor(3) + + self.assert_results(output_dict_const_key, a, b, 4) + + def test_output_nest_struct(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + + self.assert_results(output_nest_struct, a, b, 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_range.py b/test/sot/test_range.py new file mode 100644 index 00000000000000..3a7e85fb0951de --- /dev/null +++ b/test/sot/test_range.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def test_range_1(stop: int): + return range(stop) + + +def test_range_2(start: int, stop: int): + return range(start, stop) + + +def test_range_3(start: int, stop: int, step: int): + return range(start, stop, step) + + +def test_range_4(stop: int, index: int): + return range(stop)[index] + + +def test_range_5(stop: int): + return list(range(stop)) + + +def test_range_6(stop: int, index: int): + return list(range(stop))[index] + + +def test_range_7(index: int, tensor: paddle.Tensor): + return list(range(len(tensor.shape)))[index] + + +def test_range_8(stop: int): + sum = 0 + for i in range(stop): + sum += i + return sum + + +def test_range_9(stop: int, tensor: paddle.Tensor): + for i in range(stop): + tensor += i + return tensor + + +def test_range_10(stop: int, tensor: paddle.Tensor): + for i in range(stop): + for j in range(stop + 1): + tensor += j + return tensor + + +class TestExecutor(TestCaseBase): + def test_cases(self): + start = 3 + stop = 10 + step = 2 + index = 1 + tensor = paddle.randn((10, 10)) + + self.assert_results(test_range_1, stop) + self.assert_results(test_range_2, start, stop) + self.assert_results(test_range_3, start, stop, step) + self.assert_results(test_range_4, stop, index) + self.assert_results(test_range_5, stop) + self.assert_results(test_range_6, stop, index) + self.assert_results(test_range_7, index, tensor) + self.assert_results(test_range_8, stop) + + self.assert_results(test_range_9, stop, paddle.randn((10,))) + self.assert_results(test_range_10, stop, paddle.randn((10,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_resnet.py b/test/sot/test_resnet.py new file mode 100644 index 00000000000000..cc9a47252c559e --- /dev/null +++ b/test/sot/test_resnet.py @@ -0,0 +1,59 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle +from paddle.vision.models.resnet import resnet18 + + +def resnet_call(x: paddle.Tensor, net: paddle.nn.Layer): + return net(x) + + +class TestResNet(TestCaseBase): + def test_resnet_eval(self): + x = paddle.rand((10, 3, 224, 224)) + net = resnet18(pretrained=False) + net.eval() + with test_instruction_translator_cache_context() as ctx: + self.assert_results(resnet_call, x, net) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(resnet_call, x, net) # cache hit + self.assertEqual(ctx.translate_count, 1) + net.train() + self.assert_results(resnet_call, x, net) # cache miss + self.assertEqual(ctx.translate_count, 2) + + def test_resnet_train(self): + x = paddle.rand((10, 3, 224, 224)) + net = resnet18(pretrained=False) + net.train() + with test_instruction_translator_cache_context() as ctx: + self.assert_results(resnet_call, x, net) + self.assertEqual(ctx.translate_count, 1) + self.assert_results(resnet_call, x, net) # cache hit + self.assertEqual(ctx.translate_count, 1) + net.eval() + self.assert_results(resnet_call, x, net) # cache miss + self.assertEqual(ctx.translate_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_resnet50_backward.py b/test/sot/test_resnet50_backward.py new file mode 100644 index 00000000000000..bd5aac0025e802 --- /dev/null +++ b/test/sot/test_resnet50_backward.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["FLAGS_cudnn_deterministic"] = "True" + +import random +import unittest + +import numpy as np +from numpy.testing import assert_array_equal + +import paddle +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.utils.utils import execute_time +from paddle.vision import resnet50 + + +def resnet_call(net: paddle.nn.Layer, x: paddle.Tensor): + return net(x) + + +def run_dygraph_optimizer(inp): + """dygraph train + SGD optimizer""" + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + net = resnet50() + optimizer = paddle.optimizer.SGD( + learning_rate=0.03, parameters=net.parameters() + ) + for i in range(5): + optimizer.clear_grad() + loss = execute_time(net)(inp) + loss.backward() + optimizer.step() + return loss + + +def run_symbolic_optimizer(inp): + """dygraph train + SGD optimizer""" + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + net = resnet50() + net_wrapper = symbolic_translate(resnet_call) + optimizer = paddle.optimizer.SGD( + learning_rate=0.03, parameters=net.parameters() + ) + for i in range(5): + optimizer.clear_grad() + loss = execute_time(net_wrapper)(net, inp) + loss.backward() + optimizer.step() + return loss + + +def run_to_static_optimizer(inp): + """dygraph train + SGD optimizer""" + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + net = resnet50() + net = paddle.jit.to_static(net, enable_fallback=False) + optimizer = paddle.optimizer.SGD( + learning_rate=0.03, parameters=net.parameters() + ) + for i in range(5): + optimizer.clear_grad() + loss = execute_time(net)(inp) + loss.backward() + optimizer.step() + return loss + + +class TestBackward(unittest.TestCase): + def test(self): + # TODO(xiongkun) add cache to speedup ! + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + inp = paddle.rand((3, 3, 255, 255)) + print("Start Run SymbolicTranslate:") + out2 = run_symbolic_optimizer(inp)[0].numpy() + print("Start Run Dygraph:") + out1 = run_dygraph_optimizer(inp)[0].numpy() + print("Start Run To Static:") + out1 = run_to_static_optimizer(inp)[0].numpy() + assert_array_equal( + out1, out2, "Not Equal in dygraph and static graph", True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_segment_linear.py b/test/sot/test_segment_linear.py new file mode 100644 index 00000000000000..ee3b7d70f8d365 --- /dev/null +++ b/test/sot/test_segment_linear.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle import nn +from paddle.jit import sot + + +class Head(nn.Layer): + def __init__(self): + super().__init__() + self.head = nn.Linear(10, 150) + + def forward(self, x, patch_embed_size): + masks = self.head(x) + # [b, (h w), c] -> [b, c, h, w] + h, w = patch_embed_size[0], patch_embed_size[1] + masks = masks.reshape((1, h, w, paddle.shape(masks)[-1])) + masks = masks.transpose((0, 3, 1, 2)) + return masks + + +class SimpleNet(nn.Layer): + def __init__(self): + super().__init__() + self.tmp = nn.Linear(1, 1024 * 10) + self.tmp2 = nn.Linear(1, 1 * 10 * 32 * 32) + self.head = Head() + + def getshape(self, x): + x = self.tmp2(x.mean().reshape([1])).reshape([1, 10, 32, 32]) + x = paddle.shape(x) + return x + + def forward(self, x): + shape = self.getshape(x) + feat = self.tmp(x.mean().reshape([1])).reshape([1, 1024, 10]) + logits = self.head(feat, shape[2:]) + return logits + + +class TestExecutor(TestCaseBase): + def test_simple(self): + sot.skip_function(SimpleNet.forward) + x = paddle.randn((1, 8, 8)) + net = SimpleNet() + net = paddle.jit.to_static( + net + ) # dont make effect. we need fetch sot PR in paddle. + loss = net(x) + loss = loss.sum() + loss.backward() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_side_effects.py b/test/sot/test_side_effects.py new file mode 100644 index 00000000000000..46bed6e8d3c4e3 --- /dev/null +++ b/test/sot/test_side_effects.py @@ -0,0 +1,333 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase, strict_mode_guard + +import paddle +from paddle.jit import sot +from paddle.jit.sot import symbolic_translate +from paddle.jit.sot.utils import InnerError + + +def dict_setitem(x): + x[0] = 1 + return x[0] + + +def dict_delitem(x): + del x[0] + return x + + +def dict_delitem_getitem(a): + b = a[0] + del a[0] + b[0] = 1 + return a, b + + +def dict_nested_1(x): + x[0][0] = 42 + x[1][0] = x[0][0] + x[0][1] + x[2] = {1: 2} + return x + + +def dict_nested_2(x): + a = x[0] + b = x[1] + del a[0] + a[1] = b[0] + a[2] = b[1] + x[1][0] = 42 + del a[1] + return a, b + + +def list_append_int(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(12) + return tensor_x, list_a + + +def list_append_tensor(tensor_x, list_a): + tensor_x = tensor_x + 1 + list_a.append(tensor_x) + return tensor_x, list_a + + +def list_delitem(list_a): + del list_a[0] + return list_a[0] + + +def list_extend(list_a): + list_a.extend([1, 2, 3]) + return list_a[0] + + +def list_nested(list_a): + inner_list = [] + inner_list.append(list_a) + inner_list[-1].append(12) + return 12 + + +def list_insert(list_a): + list_a.insert(0, 1) + return list_a[0] + + +def list_remove(list_a): + list_a.remove(1) + return list_a[0] + + +def list_pop(list_a): + list_a.pop(0) + list_a.pop() + list_a.pop(1) + return list_a[0] + + +def list_clear(list_a): + list_a.clear() + return list_a + + +def list_sort(list_a): + list_a.sort() + return list_a + + +def list_reverse(list_a): + list_a.reverse() + return list_a + + +def slice_in_for_loop(x, iter_num=3): + x = paddle.to_tensor(x) + a = [] + + iter_num = paddle.full(shape=[1], fill_value=iter_num, dtype="int32") + + for i in range(iter_num): + a.append(x) + + for i in range(iter_num): + a[i] = x + out = a[2] + return out + + +# TODO: Object SideEffect +class CustomObject: + def __init__(self): + self.x = 2 + self.y = paddle.to_tensor(1) + + def object_attr_set2(self, x): + self.outputs = [] + self.outputs.append(x) + return self.outputs + + +@sot.psdb.check_no_breakgraph +def object_attr_set(cus_obj, t): + """object side effect.""" + t = t + 1 + cus_obj.x = t + return t, cus_obj.x + + +def object_attr_breakgraph(cus_obj, t): + t = t + 1 + sot.psdb.breakgraph() + cus_obj.x = t + sot.psdb.breakgraph() + return t, cus_obj.x + + +@sot.psdb.check_no_breakgraph +def object_attr_tensor_del(cus_obj): + del cus_obj.y + + +@sot.psdb.check_no_breakgraph +def object_attr_int_del(cus_obj): + del cus_obj.x + + +def slice_list_after_change(l): + l.reverse() + sum = 0 + for i, v in zip(range(2), l[2:]): + sum += v + return sum + + +class TestDictSideEffect(TestCaseBase): + def test_dict_setitem(self): + self.assert_results_with_side_effects( + dict_setitem, {0: paddle.to_tensor(0)} + ) + self.assert_results_with_side_effects( + dict_setitem, {0: paddle.to_tensor(1)} + ) + + def test_dict_delitem(self): + self.assert_results_with_side_effects( + dict_delitem, {0: paddle.to_tensor(0), 1: paddle.to_tensor(1)} + ) + self.assert_results_with_side_effects( + dict_delitem, {0: paddle.to_tensor(1), 2: paddle.to_tensor(2)} + ) + + def test_dict_delitem_getitem(self): + self.assert_results_with_side_effects( + dict_delitem_getitem, {0: {0: 1, 1: 2}} + ) + + def test_dict_nested_1(self): + self.assert_results_with_side_effects( + dict_nested_1, {0: {0: 1, 1: 2}, 1: {0: 1, 1: 2}} + ) + self.assert_results_with_side_effects( + dict_nested_1, {0: {0: 123, 1: 2}, 1: {0: 1, 1: 2}} + ) + + def test_dict_nested_2(self): + self.assert_results_with_side_effects( + dict_nested_2, {0: {0: 1, 1: 2}, 1: {0: 1, 1: 2}} + ) + self.assert_results_with_side_effects( + dict_nested_2, {0: {0: 123, 1: 2}, 1: {0: 1, 1: 2}} + ) + + +class TestListSideEffect(TestCaseBase): + def test_list_append(self): + self.assert_results_with_side_effects( + list_append_int, paddle.to_tensor(1), [1, 2, 3] + ) + self.assert_results_with_side_effects( + list_append_tensor, paddle.to_tensor(2), [1, 2, 3] + ) + + def test_list_delitem(self): + self.assert_results_with_side_effects(list_delitem, [1, 2, 3]) + + def test_list_extend(self): + self.assert_results_with_side_effects( + list_extend, [1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + + def test_list_insert(self): + self.assert_results_with_side_effects(list_insert, [1, 2, 3]) + self.assert_results_with_side_effects( + list_insert, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_remove(self): + self.assert_results_with_side_effects(list_remove, [1, 1, 1]) + self.assert_results_with_side_effects(list_remove, [0, 1, 2]) + with self.assertRaises(InnerError): + symbolic_translate(list_remove)([0, 2, 4]) + + def test_list_pop(self): + self.assert_results_with_side_effects(list_pop, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_pop, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_clear(self): + self.assert_results_with_side_effects(list_clear, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_clear, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_sort(self): + self.assert_results_with_side_effects(list_sort, [2, 1, 7, 3, 4, 6]) + self.assert_results_with_side_effects( + list_sort, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_list_reverse(self): + self.assert_results_with_side_effects(list_reverse, [1, 2, 3, 4, 5]) + self.assert_results_with_side_effects( + list_reverse, [-1, 2, -3, 4, -5, 6, -7, 8, -9] + ) + + def test_slice_in_for_loop(self): + x = 2 + with strict_mode_guard(0): + self.assert_results_with_side_effects(slice_in_for_loop, x) + + def test_list_nested(self): + self.assert_results_with_side_effects(list_nested, [1, 2, 3]) + + +class TestSliceAfterChange(TestCaseBase): + def test_slice_list_after_change(self): + self.assert_results_with_side_effects( + slice_list_after_change, [1, 2, 3, 4] + ) + self.assert_results_with_side_effects( + slice_list_after_change, [7, 8, 9, 10] + ) + + +class TestAttrSideEffect(TestCaseBase): + def attr_check(self, func, attr_keys: list[str], cls, *inputs): + cus_obj1 = cls() + cus_obj2 = cls() + sym_output = symbolic_translate(func)(cus_obj1, *inputs) + paddle_output = func(cus_obj2, *inputs) + for key in attr_keys: + self.assert_nest_match( + getattr(cus_obj1, key, f"__MISS_KEY__{key}"), + getattr(cus_obj2, key, f"__MISS_KEY__{key}"), + ) + self.assert_nest_match(sym_output, paddle_output) + + def test_attr_set(self): + self.attr_check(object_attr_set, ["x"], CustomObject, 5) + self.attr_check( + CustomObject.object_attr_set2, ["outputs"], CustomObject, 6 + ) + self.attr_check( + CustomObject.object_attr_set2, + ["outputs"], + CustomObject, + paddle.to_tensor(5), + ) + self.attr_check( + object_attr_set, ["x"], CustomObject, paddle.to_tensor(5) + ) + + def test_attr_del(self): + self.attr_check(object_attr_tensor_del, ["y"], CustomObject) + self.attr_check(object_attr_int_del, ["x"], CustomObject) + + def test_attr_set_breakgraph(self): + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 100) + self.attr_check(object_attr_breakgraph, ["x"], CustomObject, 1000) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_simulate_initialize.py b/test/sot/test_simulate_initialize.py new file mode 100644 index 00000000000000..495e06ac1dbda2 --- /dev/null +++ b/test/sot/test_simulate_initialize.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle import nn +from paddle.jit.sot import symbolic_translate + + +class A: + def __init__(self, vals): + vals.append(1) + + +def foo(x, y): + out = nn.Softmax()(paddle.to_tensor([x, y], dtype="float32")) + return out + + +def bar(x): + a = A(x) + t = paddle.to_tensor(x) + return t.mean() + + +class TestInit(TestCaseBase): + def test_init_paddle_layer(self): + self.assert_results(foo, 1, 2) + + def test_init_python_object(self): + sot_output = symbolic_translate(bar)([1.0, 2.0]) + dyn_output = bar([1.0, 2.0]) + self.assert_nest_match(sot_output, dyn_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_sir_rollback.py b/test/sot/test_sir_rollback.py new file mode 100644 index 00000000000000..ddb7792651e4d1 --- /dev/null +++ b/test/sot/test_sir_rollback.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import operator +import unittest + +from test_case_base import TestCaseBase + +import paddle +from paddle.jit.sot.opcode_translator.executor.function_graph import ( + FunctionGraph, +) +from paddle.jit.sot.opcode_translator.executor.tracker import ( + DanglingTracker, + LocalTracker, +) +from paddle.jit.sot.opcode_translator.executor.variables import ( + BuiltinVariable, + VariableFactory, +) + + +def compute(x, y): + ret = BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y) + return BuiltinVariable(operator.mul, x.graph, DanglingTracker())(ret, x) + + +def try_add(x, y): + return BuiltinVariable(operator.add, x.graph, DanglingTracker())(x, y) + + +class TestRollback(TestCaseBase): + def test_rollback(self): + frame = inspect.currentframe() + graph = FunctionGraph(frame) + a = paddle.to_tensor(1.0) + b = paddle.to_tensor(2.0) + a = VariableFactory().from_value(a, graph, LocalTracker("a")) + b = VariableFactory().from_value(b, graph, LocalTracker("b")) + out = compute(a, b) + original_length = len(graph.sir_ctx.TOS.statements) + memo = graph.save_memo() + try_add(out, out) + + assert len(graph.sir_ctx.TOS.statements) != len( + memo.stmt_ir.statements + ), "After add, we must statement IR." + graph.restore_memo(memo) + + assert len(graph.sir_ctx.TOS.statements) == original_length + + +def fn_with_side_effects_inner(x, y): + x[0] += 10 + x[1] += 20 + x[2] -= 10 + print(y) # print will cause breakgraph + + +def fn_with_side_effects(x, y): + x[0] += 1 + fn_with_side_effects_inner(x, y) + return x[0] + y + + +class TestSideEffectRollback(TestCaseBase): + def test_side_effect_rollback(self): + self.assert_results_with_side_effects( + fn_with_side_effects, [1, 2, 3], paddle.to_tensor(42) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_stack.py b/test/sot/test_stack.py new file mode 100644 index 00000000000000..e29610b2c837cf --- /dev/null +++ b/test/sot/test_stack.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.jit.sot.opcode_translator.executor.variable_stack import ( + VariableStack, +) + + +class TestVariableStack(unittest.TestCase): + def test_basic(self): + stack = VariableStack([1, 2, 3]) + self.assertEqual(str(stack), "[1, 2, 3]") + self.assertEqual(len(stack), 3) + self.assertEqual(str(stack.copy()), str(stack)) + + def test_peek(self): + stack = VariableStack([1, 2, 3]) + self.assertEqual(stack.peek(), 3) + self.assertEqual(stack.top, 3) + self.assertEqual(stack.peek(1), 3) + stack.peek[1] = 4 + stack.peek[2] = 3 + self.assertEqual(stack.peek[1], 4) + self.assertEqual(stack.peek[:1], [4]) + self.assertEqual(stack.peek[:2], [3, 4]) + stack.top = 5 + self.assertEqual(stack.peek[:2], [3, 5]) + + def test_push_pop(self): + stack = VariableStack() + stack.push(1) + stack.push(2) + self.assertEqual(stack.pop(), 2) + self.assertEqual(stack.pop(), 1) + + def test_pop_n(self): + stack = VariableStack([1, 2, 3, 4]) + self.assertEqual(stack.pop_n(2), [3, 4]) + self.assertEqual(stack.pop_n(2), [1, 2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_str_format.py b/test/sot/test_str_format.py new file mode 100644 index 00000000000000..34bbd6e31f3dde --- /dev/null +++ b/test/sot/test_str_format.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import TestCaseBase + + +# copy from python library _distutils_hack/__init__.py +def find_spec(self, fullname, path, target=None): + method_name = 'spec_for_{fullname}'.format( + **{'self': self, 'fullname': fullname} + ) + method = getattr(self, method_name, lambda: None) + return method() + + +class TestExecutor(TestCaseBase): + def test_simple(self): + self.assert_results(find_spec, "self", "fullname", "path", None) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_tensor_dtype_in_guard.py b/test/sot/test_tensor_dtype_in_guard.py new file mode 100644 index 00000000000000..d5d001b7038d0d --- /dev/null +++ b/test/sot/test_tensor_dtype_in_guard.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle +from paddle.jit import sot + + +def foo(x, y): + if x.dtype == paddle.float32: + out = x + y + else: + out = x - y + return out + + +@sot.skip_function +def dtype_in_guard(x, y): + with paddle.amp.auto_cast(level='O2'): + for i in range(10): + z = foo(x, y) + x = z + return x + + +def bar(x, y): + if x == paddle.float32: + return y + 1 + else: + return y - 1 + + +@sot.skip_function +def dtype_as_input(x, y): + with paddle.amp.auto_cast(level='O2'): + for i in range(10): + z = bar(x, y) + y = z + return y + + +class TestDtypeInGuard(TestCaseBase): + def test_dtype_in_guard(self): + with test_instruction_translator_cache_context() as ctx: + x = paddle.to_tensor([2], dtype="float32") + y = paddle.to_tensor([3], dtype="float32") + self.assert_results(dtype_in_guard, x, y) + self.assertEqual(ctx.translate_count, 1) + + def test_input_dtype_in_guard(self): + with test_instruction_translator_cache_context() as ctx: + x = paddle.float32 + y = paddle.to_tensor([3], dtype="float32") + self.assert_results(dtype_as_input, x, y) + self.assertEqual(ctx.translate_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_tensor_slice.py b/test/sot/test_tensor_slice.py new file mode 100644 index 00000000000000..32c52759da4387 --- /dev/null +++ b/test/sot/test_tensor_slice.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_case_base import TestCaseBase + +import paddle + + +def foo(x: paddle.Tensor): + return x[:, 0] + + +class TestExecutor(TestCaseBase): + def test_tensor_slice(self): + x = paddle.randn((10, 10)) + self.assert_results(foo, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sot/test_trace_list_arg.py b/test/sot/test_trace_list_arg.py new file mode 100644 index 00000000000000..8a82406a11f754 --- /dev/null +++ b/test/sot/test_trace_list_arg.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from test_case_base import ( + TestCaseBase, + test_instruction_translator_cache_context, +) + +import paddle + + +def foo(x: list[paddle.Tensor], y: list[paddle.Tensor]): + return x[0] + y[0] + + +def bar(x: list[paddle.Tensor], y: int, z: int): + return x[y + z] + 1 + + +class TestTraceListArg(TestCaseBase): + def test_foo(self): + a = paddle.to_tensor(1) + b = paddle.to_tensor(2) + c = paddle.to_tensor([3, 4]) + + with test_instruction_translator_cache_context() as cache: + self.assert_results(foo, [a], [b]) + self.assertEqual(cache.translate_count, 1) + self.assert_results(foo, [b], [a]) # Cache hit + self.assertEqual(cache.translate_count, 1) + self.assert_results(foo, [a], [c]) # Cache miss + self.assertEqual(cache.translate_count, 2) + + def test_bar(self): + a = [paddle.to_tensor(1), paddle.to_tensor(2), paddle.to_tensor(3)] + b = [paddle.to_tensor([2, 3]), paddle.to_tensor(4), paddle.to_tensor(5)] + + with test_instruction_translator_cache_context() as cache: + self.assert_results(bar, a, 1, 1) + self.assertEqual(cache.translate_count, 1) + self.assert_results(bar, a, 2, 0) # Cache miss + self.assertEqual(cache.translate_count, 2) + self.assert_results(bar, b, 1, 1) # Cache hit + self.assertEqual(cache.translate_count, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list index cb33edca9dfbe8..dea0398f9d5fac 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/new_ir_op_test_white_list @@ -61,6 +61,7 @@ test_diag_v2 test_digamma_op test_dist_op test_dot_op +test_dpsgd_op test_edit_distance_op test_eigh_op test_eigh_op_static_build diff --git a/third_party/cccl b/third_party/cccl new file mode 160000 index 00000000000000..1f6e4bcae0fbf1 --- /dev/null +++ b/third_party/cccl @@ -0,0 +1 @@ +Subproject commit 1f6e4bcae0fbf1bbed87f88544d8d2161c490fc1 diff --git a/third_party/mkldnn b/third_party/mkldnn index 64f6bcbcbab628..01204edbda1c2a 160000 --- a/third_party/mkldnn +++ b/third_party/mkldnn @@ -1 +1 @@ -Subproject commit 64f6bcbcbab628e96f33a62c3e975f8535a7bde4 +Subproject commit 01204edbda1c2a4ff0cccd40476ed6bd2fb62d56 diff --git a/third_party/openblas b/third_party/openblas index 394a9fbafe9010..5f36f18148603f 160000 --- a/third_party/openblas +++ b/third_party/openblas @@ -1 +1 @@ -Subproject commit 394a9fbafe9010b76a2615c562204277a956eb52 +Subproject commit 5f36f18148603facb6c3540e673610d6b24cbfbb diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index 947d7bed767060..d1cb054771535b 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -29,6 +29,10 @@ function collect_failed_tests() { serial_list="^test_conv2d_op$|\ ^test_conv2d_transpose_op$|\ +^test_dygraph_dataparallel_bf16$|\ +^test_dygraph_sharding_stage1_fp16$|\ +^test_dygraph_sharding_stage2_bf16$|\ +^test_dygraph_sharding_stage3_bf16$|\ ^test_conv3d_op$" parallel_list="^init_phi_test$|\ diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py index a89dafff96ab6b..b1a19e118e7e4b 100755 --- a/tools/parallel_UT_rule.py +++ b/tools/parallel_UT_rule.py @@ -284,7 +284,6 @@ 'test_depthwise_conv_mkldnn_pass', 'test_fleet_metric', 'test_fc_fuse_pass_cc', - 'test_fleet_private_function', 'test_fleet', 'test_executor_check_feed', 'test_py_reader_lod_level_share', @@ -2121,7 +2120,6 @@ 'test_dgc_optimizer', 'heter_server_test', 'test_custom_conj', - 'test_fleet_private_function', 'test_fake_init_op', 'brpc_service_sparse_sgd_test', 'test_tf32_cudnn', diff --git a/tools/test_runner.py b/tools/test_runner.py index 37d132fbc1535a..49603fd9a3afa5 100644 --- a/tools/test_runner.py +++ b/tools/test_runner.py @@ -40,6 +40,7 @@ def main(): sys.path.append(os.getcwd()) + os.environ["FLAGS_dynamic_static_unified_comm"] = "false" if core.is_compiled_with_cuda() or core.is_compiled_with_rocm(): if os.getenv('FLAGS_enable_gpu_memory_usage_log') is None: os.environ['FLAGS_enable_gpu_memory_usage_log'] = 'true'