From ee348bbfe5035d1268861cbfc5cfca62c337949d Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 25 Oct 2023 04:43:37 +0000 Subject: [PATCH 1/5] support disttensor for tensor.copy_ --- paddle/phi/api/lib/tensor_method.cc | 70 ++++++++++++++- .../semi_auto_parallel_recompute.py | 89 +++++++++++++++++++ .../test_semi_auto_parallel_recompute.py | 57 ++++++++++++ 3 files changed, 212 insertions(+), 4 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_recompute.py create mode 100644 test/auto_parallel/test_semi_auto_parallel_recompute.py diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 74ee1e380dcc4a..59737c3aed2ee7 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -27,7 +27,11 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/infermeta/unary.h" // clang-format off - +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h" +#include "paddle/phi/api/lib/data_transform.h" +#endif namespace paddle { namespace experimental { // declare cast api @@ -87,9 +91,7 @@ void Tensor::copy_(const Tensor &src, VLOG(8) << "Src is empty, skip copy"; return; } - // Prepare copy kernel key and outputs - auto kernel_key_set = ParseKernelKeyByInputArgs(src); - KernelType kernel_type = ParseKernelTypeByInputArgs(src); + VLOG(3) << "Deep copy Tensor from " << src.name() << " to " << name(); if (initialized()) { PADDLE_ENFORCE_EQ(dtype(), @@ -114,6 +116,12 @@ void Tensor::copy_(const Tensor &src, "Copy cannot be performed!", target_place, place())); + } + + // Prepare copy kernel key and outputs + auto kernel_key_set = ParseKernelKeyByInputArgs(src); + KernelType kernel_type = ParseKernelTypeByInputArgs(src); + if (initialized()) { kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place())); } else { @@ -128,6 +136,60 @@ void Tensor::copy_(const Tensor &src, auto *dev_ctx = pool.GetMutable( place.GetType() == target_place.GetType() ? target_place : place); +// #ifdef PADDLE_WITH_DISTRIBUTE + bool run_auto_parallel = AllInputsAreDistTensor(src); + bool rank_is_in_current_mesh = false; + if (run_auto_parallel) { + auto mesh = std::static_pointer_cast( + src.impl())->dist_attr().process_mesh(); + rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); + + // 1. InferSpmd (Infer DistAttr of Inputs&Outputs) + auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); + auto spmd_info = phi::distributed::ElementwiseUnaryInferSpmd( + meta_dist_input_x); + + // 2. Create API Output & Prepare Dist and Dense Output + auto dist_out = SetKernelDistOutput(this, spmd_info.second[0]); + auto dense_out = dist_out->unsafe_mutable_value(); + if (!rank_is_in_current_mesh) { + *dense_out = phi::DenseTensor( + std::make_shared(nullptr, + 0, phi::distributed::GetDefaultPlace()), + phi::DenseTensorMeta()); + } + + // 3. Infer DistTensor's Global Shape + phi::MetaTensor meta_dist_out(dist_out); + phi::UnchangedInferMeta(MakeMetaTensor(*(src.impl_)), &meta_dist_out); + + if (rank_is_in_current_mesh) { + // 4. Select Kernel + + // 5. Reshard Input + auto dist_input_x = ReshardApiInputToKernelInput( + dev_ctx, src, spmd_info.first[0]); + + // 6. PrepareData (DataTransform & Prepare Dense Input) + auto input_x = &dist_input_x->value(); + + // 7. 
Infer Local DenseTensor Meta + phi::MetaTensor meta_dense_out(dense_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_x), &meta_dense_out); + + // 8. DenseTensor Kernel Call + phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out); + + // 9. Reshard Partial Output to Replicated (Temporary) + ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out); + } + + // 10. Set Output Dist Attr For Default Impl + // API `copy_` does not need to set DistAttr for output. + return; + } +// #endif + if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { SetKernelOutput(this); phi::MetaTensor meta_out(impl_.get()); diff --git a/test/auto_parallel/semi_auto_parallel_recompute.py b/test/auto_parallel/semi_auto_parallel_recompute.py new file mode 100644 index 00000000000000..190181549a918c --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_recompute.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed.fleet.utils import recompute + +BATCH_SIZE = 16 +BATCH_NUM = 4 +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +class MPDemoNet(nn.Layer): + def __init__(self, np_w0, np_w1, mesh, param_suffix=""): + super().__init__() + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="mp_demo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w0), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="mp_nemo_weight_2" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w1), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), + ) + + def _inner_forward_fn(self, x): + y = paddle.matmul(x, self.w0) + z = paddle.matmul(y, self.w1) + return z + + def forward(self, x): + z = recompute(self._inner_forward_fn, x) + return z + + +def run_dynamic(layer, image, label): + # create loss + loss_fn = nn.MSELoss() + # run forward and backward + image = paddle.to_tensor(image) + image.stop_gradient = False + out = layer(image) + + label = paddle.to_tensor(label) + loss = loss_fn(out, label) + + loss.backward() + return loss, layer.w0.grad, layer.w1.grad + + +class TestSemiAutoParallelRecompute: + def test_recompute(): + mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + image = np.random.random([BATCH_SIZE, IMAGE_SIZE]).astype('float32') + label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') + w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') + w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') + run_dynamic(layer=MPDemoNet(w0, w1, mesh), image=image, label=label) + + +if __name__ == "__main__": + TestSemiAutoParallelRecompute.test_recompute() diff --git 
a/test/auto_parallel/test_semi_auto_parallel_recompute.py b/test/auto_parallel/test_semi_auto_parallel_recompute.py new file mode 100644 index 00000000000000..2582b8bf55ded6 --- /dev/null +++ b/test/auto_parallel/test_semi_auto_parallel_recompute.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import collective.test_communication_api_base as test_base + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TestSemiAutoParallelRecompute(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=2, timeout=120, nnode=2) + self._default_envs = { + "dtype": "float32", + "seed": "2023", + } + self._changeable_envs = { + "backend": ["gpu"], + } + + def test_simple_net_bybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_recompute.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() From 9554f34141783e860548df36871db0effce616ff Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 25 Oct 2023 04:43:59 +0000 Subject: [PATCH 2/5] support disttensor for tensor.copy_ --- paddle/phi/api/lib/tensor_method.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 59737c3aed2ee7..1a11aab785352f 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -136,7 +136,7 @@ void Tensor::copy_(const Tensor &src, auto *dev_ctx = pool.GetMutable( place.GetType() == target_place.GetType() ? target_place : place); -// #ifdef PADDLE_WITH_DISTRIBUTE +#ifdef PADDLE_WITH_DISTRIBUTE bool run_auto_parallel = AllInputsAreDistTensor(src); bool rank_is_in_current_mesh = false; if (run_auto_parallel) { @@ -188,7 +188,7 @@ void Tensor::copy_(const Tensor &src, // API `copy_` does not need to set DistAttr for output. 
return; } -// #endif +#endif if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { SetKernelOutput(this); From 4e252abeee8aa608d6d5076eb6a6d2255a92764a Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 25 Oct 2023 06:19:48 +0000 Subject: [PATCH 3/5] refine --- test/auto_parallel/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index a5513adafb6cc9..086ad4465d0907 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -127,6 +127,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_semi_auto_parallel_hybrid_strategy) set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_semi_auto_parallel_recompute MODULES + test_semi_auto_parallel_recompute) + set_tests_properties(test_semi_auto_parallel_recompute + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir) set_tests_properties(test_gpt_with_newir PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) From 015b19524fb76ff0ef5eba3b28e6a47eb09fa1eb Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 25 Oct 2023 07:39:09 +0000 Subject: [PATCH 4/5] refine --- test/auto_parallel/CMakeLists.txt | 4 -- .../semi_auto_parallel_recompute.py | 40 ++----------- .../semi_auto_parallel_simple_net.py | 35 ++++++++++++ .../test_semi_auto_parallel_recompute.py | 57 ------------------- ...test_semi_auto_parallel_single_strategy.py | 10 ++++ 5 files changed, 49 insertions(+), 97 deletions(-) delete mode 100644 test/auto_parallel/test_semi_auto_parallel_recompute.py diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index 086ad4465d0907..a5513adafb6cc9 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -127,10 +127,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU) test_semi_auto_parallel_hybrid_strategy) set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) - py_test_modules(test_semi_auto_parallel_recompute MODULES - test_semi_auto_parallel_recompute) - set_tests_properties(test_semi_auto_parallel_recompute - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir) set_tests_properties(test_gpt_with_newir PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100) diff --git a/test/auto_parallel/semi_auto_parallel_recompute.py b/test/auto_parallel/semi_auto_parallel_recompute.py index 190181549a918c..7329a1f4d0bafb 100644 --- a/test/auto_parallel/semi_auto_parallel_recompute.py +++ b/test/auto_parallel/semi_auto_parallel_recompute.py @@ -14,11 +14,11 @@ import numpy as np +from semi_auto_parallel_simple_net import MPDemoNetRecompute import paddle import paddle.distributed as dist from paddle import nn -from paddle.distributed.fleet.utils import recompute BATCH_SIZE = 16 BATCH_NUM = 4 @@ -26,40 +26,6 @@ CLASS_NUM = 10 -class MPDemoNet(nn.Layer): - def __init__(self, np_w0, np_w1, mesh, param_suffix=""): - super().__init__() - self.w0 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, IMAGE_SIZE], - attr=paddle.framework.ParamAttr( - name="mp_demo_weight_1" + param_suffix, - initializer=paddle.nn.initializer.Assign(np_w0), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), - ) - self.w1 = dist.shard_tensor( - self.create_parameter( - shape=[IMAGE_SIZE, CLASS_NUM], - attr=paddle.framework.ParamAttr( - 
name="mp_nemo_weight_2" + param_suffix, - initializer=paddle.nn.initializer.Assign(np_w1), - ), - ), - dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), - ) - - def _inner_forward_fn(self, x): - y = paddle.matmul(x, self.w0) - z = paddle.matmul(y, self.w1) - return z - - def forward(self, x): - z = recompute(self._inner_forward_fn, x) - return z - - def run_dynamic(layer, image, label): # create loss loss_fn = nn.MSELoss() @@ -82,7 +48,9 @@ def test_recompute(): label = np.random.random([BATCH_SIZE, CLASS_NUM]).astype('float32') w0 = np.random.random([IMAGE_SIZE, IMAGE_SIZE]).astype('float32') w1 = np.random.random([IMAGE_SIZE, CLASS_NUM]).astype('float32') - run_dynamic(layer=MPDemoNet(w0, w1, mesh), image=image, label=label) + run_dynamic( + layer=MPDemoNetRecompute(w0, w1, mesh), image=image, label=label + ) if __name__ == "__main__": diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py index fb7d0b4406697d..6aa1090b2787aa 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net.py @@ -19,6 +19,7 @@ import paddle import paddle.distributed as dist from paddle import nn +from paddle.distributed.fleet.utils import recompute BATCH_SIZE = 16 BATCH_NUM = 4 @@ -114,6 +115,40 @@ def forward(self, x): return z +class MPDemoNetRecompute(nn.Layer): + def __init__(self, np_w0, np_w1, mesh, param_suffix=""): + super().__init__() + self.w0 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, IMAGE_SIZE], + attr=paddle.framework.ParamAttr( + name="mp_demo_weight_1" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w0), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']), + ) + self.w1 = dist.shard_tensor( + self.create_parameter( + shape=[IMAGE_SIZE, CLASS_NUM], + attr=paddle.framework.ParamAttr( + name="mp_nemo_weight_2" + param_suffix, + initializer=paddle.nn.initializer.Assign(np_w1), + ), + ), + dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=['x', None]), + ) + + def _inner_forward_fn(self, x): + y = paddle.matmul(x, self.w0) + z = paddle.matmul(y, self.w1) + return z + + def forward(self, x): + z = recompute(self._inner_forward_fn, x) + return z + + class PPDemoNet(nn.Layer): def __init__(self, np_w0, np_w1, mesh0, mesh1): super().__init__() diff --git a/test/auto_parallel/test_semi_auto_parallel_recompute.py b/test/auto_parallel/test_semi_auto_parallel_recompute.py deleted file mode 100644 index 2582b8bf55ded6..00000000000000 --- a/test/auto_parallel/test_semi_auto_parallel_recompute.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import collective.test_communication_api_base as test_base - -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class TestSemiAutoParallelRecompute(test_base.CommunicationTestDistBase): - def setUp(self): - super().setUp(num_of_devices=2, timeout=120, nnode=2) - self._default_envs = { - "dtype": "float32", - "seed": "2023", - } - self._changeable_envs = { - "backend": ["gpu"], - } - - def test_simple_net_bybrid_strategy(self): - envs_list = test_base.gen_product_envs_list( - self._default_envs, self._changeable_envs - ) - for envs in envs_list: - self.run_test_case( - "semi_auto_parallel_recompute.py", - user_defined_envs=envs, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py index 03b31f70a9e9b3..42456d2d737ce6 100644 --- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py +++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py @@ -50,6 +50,16 @@ def test_simple_net_single_strategy_with_amp(self): user_defined_envs=envs, ) + def test_simple_net_recompute(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_recompute.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main() From efc708d9278c8a29d07eb78b82215899e2d8d494 Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Wed, 25 Oct 2023 08:29:34 +0000 Subject: [PATCH 5/5] refine --- paddle/phi/api/lib/tensor_method.cc | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 1a11aab785352f..deebdbe0019ee7 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -136,6 +136,7 @@ void Tensor::copy_(const Tensor &src, auto *dev_ctx = pool.GetMutable( place.GetType() == target_place.GetType() ? target_place : place); + if (kernel_type == KernelType::DENSE_TENSOR_KENREL) { #ifdef PADDLE_WITH_DISTRIBUTE bool run_auto_parallel = AllInputsAreDistTensor(src); bool rank_is_in_current_mesh = false; @@ -146,11 +147,9 @@ void Tensor::copy_(const Tensor &src, // 1. InferSpmd (Infer DistAttr of Inputs&Outputs) auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); - auto spmd_info = phi::distributed::ElementwiseUnaryInferSpmd( - meta_dist_input_x); // 2. Create API Output & Prepare Dist and Dense Output - auto dist_out = SetKernelDistOutput(this, spmd_info.second[0]); + auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr()); auto dense_out = dist_out->unsafe_mutable_value(); if (!rank_is_in_current_mesh) { *dense_out = phi::DenseTensor( @@ -167,8 +166,8 @@ void Tensor::copy_(const Tensor &src, // 4. Select Kernel // 5. Reshard Input - auto dist_input_x = ReshardApiInputToKernelInput( - dev_ctx, src, spmd_info.first[0]); + auto dist_input_x = static_cast( + src.impl().get());; // 6. PrepareData (DataTransform & Prepare Dense Input) auto input_x = &dist_input_x->value(); @@ -181,7 +180,6 @@ void Tensor::copy_(const Tensor &src, phi::Copy(*dev_ctx, *input_x, target_place, blocking, dense_out); // 9. 
Reshard Partial Output to Replicated (Temporary)
-      ReshardOutputPartialAxisToReplicated(dev_ctx, dist_out);
     }

     // 10. Set Output Dist Attr For Default Impl
     // API `copy_` does not need to set DistAttr for output.
     return;
   }
 #endif
-
-  if (kernel_type == KernelType::DENSE_TENSOR_KENREL) {
   SetKernelOutput(this);
   phi::MetaTensor meta_out(impl_.get());
   phi::UnchangedInferMeta(
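
For context, a minimal sketch of the behaviour this series enables: calling copy_ between two DistTensors placed on a process mesh, which now routes through the InferSpmd / reshard / phi::Copy branch added above instead of the dense-only path. The sketch reuses only public APIs that already appear in the tests of this series (dist.ProcessMesh, dist.shard_tensor, dist.DistAttr); the dygraph dst.copy_(src, blocking) call is assumed to be the Python binding of the patched C++ Tensor::copy_ and is illustrative, not taken from this patch.

# Illustrative only -- not part of this patch series. Assumes the dygraph
# binding paddle.Tensor.copy_(src, blocking) forwards to the patched
# C++ Tensor::copy_ shown above.
import numpy as np

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Source and destination tensors sharded along the same mesh axis.
src = dist.shard_tensor(
    paddle.to_tensor(np.random.random([8, 8]).astype("float32")),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None]),
)
dst = dist.shard_tensor(
    paddle.zeros([8, 8], dtype="float32"),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None]),
)

# With this series applied, the call below takes the DistTensor branch:
# InferSpmd on the input, reshard if needed, then phi::Copy on the local
# DenseTensor value; ranks outside the mesh get an empty local value.
dst.copy_(src, True)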