diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc
index e8da46308c67da..d97dec43aca143 100644
--- a/paddle/phi/core/device_context.cc
+++ b/paddle/phi/core/device_context.cc
@@ -18,8 +18,6 @@
 #include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
 #endif
 
-#include
-#include
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/selected_rows.h"
@@ -156,7 +154,6 @@ struct DeviceContext::Impl {
         ClearHolder(tensor);
       }
     } else {
-      VLOG(0) << "Segment Fault is about to come.";
       if (tensor->initialized() && tensor->place() != place) {
         ClearHolder(tensor);
       }
diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc
index 7a4bd57766c8ef..7c9287811396c5 100644
--- a/paddle/phi/infermeta/ternary.cc
+++ b/paddle/phi/infermeta/ternary.cc
@@ -1507,12 +1507,12 @@ void EmbeddingBagInferMeta(const MetaTensor& input,
   PADDLE_ENFORCE_EQ(
       ids_dims,
       weight_dims,
-      phi::errors::InvalidArgument(
-          "ShapeError: The shapes of 'input' and 'per_sample_weight' must be the same."
-          "But received input's shape = [%s],"
-          "per_sample_weight's shape = [%s].",
-          ids_dims,
-          weight_dims));
+      phi::errors::InvalidArgument("ShapeError: The shapes of 'input' and "
+                                   "'per_sample_weight' must be the same. "
+                                   "But received input's shape = [%s], "
+                                   "per_sample_weight's shape = [%s].",
+                                   ids_dims,
+                                   weight_dims));
   PADDLE_ENFORCE_EQ(
       table_dims.size(),
       2,
diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h
index b1fb8c2e9e143b..600d91bc50e2da 100644
--- a/paddle/phi/infermeta/ternary.h
+++ b/paddle/phi/infermeta/ternary.h
@@ -64,8 +64,8 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
                        MetaConfig config = MetaConfig());
 
 void EmbeddingBagInferMeta(const MetaTensor& input,
-                           const MetaTensor& params,
                            const MetaTensor& weight,
+                           const MetaTensor& per_sample_weight,
                            MetaTensor* out);
 
 void DpsgdInferMeta(const MetaTensor& param,
diff --git a/paddle/phi/kernels/cpu/embedding_bag_kernel.cc b/paddle/phi/kernels/cpu/embedding_bag_kernel.cc
index 389f69e64f53d7..9b76015374afef 100644
--- a/paddle/phi/kernels/cpu/embedding_bag_kernel.cc
+++ b/paddle/phi/kernels/cpu/embedding_bag_kernel.cc
@@ -35,9 +35,11 @@ struct EmbeddingBagCPUFunctor {
         mode_(mode),
         out_(out) {}
 
-  using EigenArrayMap = Eigen::Map>;
+  using EigenArrayMap =
+      Eigen::Map>;
   using EigenVectorMap = Eigen::Map>;
-  using ConstEigenVectorMap = Eigen::Map>;
+  using ConstEigenVectorMap =
+      Eigen::Map>;
   using EigenIndex = Eigen::Index;
 
   template
@@ -61,7 +63,8 @@ struct EmbeddingBagCPUFunctor {
         const ConstEigenVectorMap weight_slice(
             &weight_d[input_d[bag * sequence_length + seq] * output_dim],
             output_dim);
-        output_slice += weight_slice * per_sample_weight_d[bag * sequence_length + seq];
+        output_slice +=
+            weight_slice * per_sample_weight_d[bag * sequence_length + seq];
       }
       if (mode_ == "mean") {
         output_slice /= static_cast(sequence_length);
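For reference, the Eigen loop above reduces each bag of indices to a single output row: the row of `weight` selected by each index is scaled by its `per_sample_weight` entry and accumulated, then divided by the bag length for "mean". A minimal NumPy sketch of that logic follows (the function name and shape conventions are chosen for illustration and are not part of the PR):

import numpy as np

def embedding_bag_ref(input, weight, per_sample_weight, mode="sum"):
    # input: (bags, sequence_length) integer indices; weight: (num_embeddings, output_dim)
    bags, sequence_length = input.shape
    out = np.zeros((bags, weight.shape[1]), dtype=weight.dtype)
    for bag in range(bags):
        for seq in range(sequence_length):
            # mirrors: output_slice += weight_slice * per_sample_weight_d[bag * sequence_length + seq]
            out[bag] += weight[input[bag, seq]] * per_sample_weight[bag, seq]
        if mode == "mean":
            out[bag] /= sequence_length
    return out
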
diff --git a/paddle/phi/kernels/embedding_bag_kernel.h b/paddle/phi/kernels/embedding_bag_kernel.h
index d1467089cf837b..7baf7677ab5fc6 100644
--- a/paddle/phi/kernels/embedding_bag_kernel.h
+++ b/paddle/phi/kernels/embedding_bag_kernel.h
@@ -18,7 +18,7 @@
 
 namespace phi {
 
-enum class CalMode { ksum, kmean, kmax};
+enum class CalMode { ksum, kmean, kmax };
 
 template
 void EmbeddingBagCUDAKernel(const Context& ctx,
@@ -26,6 +26,6 @@ void EmbeddingBagCUDAKernel(const Context& ctx,
                             const DenseTensor& weight,
                             const DenseTensor& per_sample_weight,
                             int64_t padding_idx,
-                            const std::string &mode,
-                            DenseTensor *out);
+                            const std::string& mode,
+                            DenseTensor* out);
 } // namespace phi
diff --git a/paddle/phi/kernels/gpu/embedding_bag_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_bag_grad_kernel.cu
index 1841a71894e9ab..e03d75c77f5abd 100644
--- a/paddle/phi/kernels/gpu/embedding_bag_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_bag_grad_kernel.cu
@@ -16,14 +16,14 @@
 #include
 #include
 #include
-#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
 namespace phi {
 
-enum class CalMode_c { ksum, kmean, kmax};
+enum class CalMode_c { ksum, kmean, kmax };
 
 // kernelfunc, calculate the grad of the variable 'weight'
 template
@@ -56,7 +56,7 @@ __global__ void EmbeddingBagWeightsGrad(const int output_dim,
 }
 // asist in obtain the map between the indices and the rows of params
 // can refer 'index_vec' in embedding_bag_grad_kernel.cc(in line 83)
-// 
+//
 template
 __global__ void PrepTempArraysKernel(const IdT *indices,
                                      IdT *sortedIndices,
@@ -83,9 +83,10 @@ __global__ void EmbeddingBagParamsGrad(const int output_dim,
   const int feature_idx = threadIdx.x + bag_idx * blockDim.x;
   const int params_idx = __ldg(sortedIndices + sample_idx);
   // refer embeddingbag in tensorflow/addons, spin up a warp for each element
-  // of the indices array, having each warp check the previous element, 
+  // of the indices array, having each warp check the previous element,
   // if the same, return without operations. If not, the warp iterates forward
-  // and accumulates gradient. The operation is to avoid repeated reads and writes
+  // and accumulates the gradient, which avoids repeated reads and
+  // writes.
   if (sample_idx > 0) {
     const int prev_idx = __ldg(sortedIndices + sample_idx - 1);
     if (prev_idx == params_idx) {
       return;
     }
   }
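The comments above describe the same strategy as the TensorFlow Addons EmbeddingBag gradient: the flattened indices are sorted so that duplicate table rows become adjacent, one warp is launched per occurrence, and only the warp at the first occurrence of a row walks forward and accumulates, so each gradient row has a single writer. A rough sequential NumPy model of the accumulation the kernel parallelizes follows (illustrative names, no CUDA semantics, per-sample-weight scaling omitted for brevity):

import numpy as np

def params_grad_ref(indices, grad_out, num_rows, output_dim):
    # indices: flattened bag indices; grad_out: one upstream-gradient row per index
    order = np.argsort(indices, kind="stable")  # PrepTempArraysKernel: co-sorted copy
    sorted_idx = indices[order]                 # plays the role of sortedIndices
    grad_weight = np.zeros((num_rows, output_dim), dtype=grad_out.dtype)
    for pos in range(len(sorted_idx)):
        if pos > 0 and sorted_idx[pos - 1] == sorted_idx[pos]:
            continue  # duplicate row: first occurrence handles it, like the early return
        run = pos
        while run < len(sorted_idx) and sorted_idx[run] == sorted_idx[pos]:
            grad_weight[sorted_idx[pos]] += grad_out[order[run]]  # iterate forward, accumulate
            run += 1
    return grad_weight
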
@@ -192,10 +193,10 @@
     dim3 grids_2(total_blocks, 1, 1);
 
-    // the target of these operations is to avoid parallel writes to the same element of
-    // the grads. So 'PrepTempArraysKernel' is designed to pre-sorting a copy of the indices(sourtedIndices),
-    // and co-sorting a counter(sortedIndicesCounter).
-
+    // The goal of these operations is to avoid parallel writes to the same
+    // element of the grads, so 'PrepTempArraysKernel' pre-sorts a copy of
+    // the indices (sortedIndices) and co-sorts a counter
+    // (sortedIndicesCounter).
     PrepTempArraysKernel
         <<>>(
diff --git a/paddle/phi/kernels/gpu/embedding_bag_kernel.cu b/paddle/phi/kernels/gpu/embedding_bag_kernel.cu
index 9cc490e3996228..b7ec93c1b07e8a 100644
--- a/paddle/phi/kernels/gpu/embedding_bag_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_bag_kernel.cu
@@ -41,7 +41,7 @@ __global__ void EmbeddingBag(T *output,
   int padding_idx_count = 0;
   T sum = static_cast(0);
   T max_d = static_cast(0);
-  for (int j = 0; j < S; j ++) {
+  for (int j = 0; j < S; j++) {
     auto id = static_cast(ids[idy * S + j]);
     const T *tab = table + id * D;
     if (PaddingFlag && id == padding_idx) {
@@ -57,8 +57,10 @@ __global__ void EmbeddingBag(T *output,
   if (mode == CalMode::ksum) {
     out[i] = sum;
   } else if (mode == CalMode::kmean) {
-    if (padding_idx_count == S) out[i] = static_cast(0);
-    else out[i] = sum / (S - padding_idx_count);
+    if (padding_idx_count == S)
+      out[i] = static_cast(0);
+    else
+      out[i] = sum / (S - padding_idx_count);
   } else {
     out[i] = max_d;
   }
@@ -108,11 +110,29 @@ struct EmbeddingBagCUDAFunctor {
     if (mode_ == "max") mode_enum = CalMode::kmax;
 
     if (padding_idx_ == -1) {
-      EmbeddingBag<<>>(
-          output_d, weight_d, ids_d, per_sample_weight_d, N, K, D, S, padding_idx_, mode_enum);
+      EmbeddingBag
+          <<>>(output_d,
+                 weight_d,
+                 ids_d,
+                 per_sample_weight_d,
+                 N,
+                 K,
+                 D,
+                 S,
+                 padding_idx_,
+                 mode_enum);
     } else {
-      EmbeddingBag<<>>(
-          output_d, weight_d, ids_d, per_sample_weight_d, N, K, D, S, padding_idx_, mode_enum);
+      EmbeddingBag
+          <<>>(output_d,
+                 weight_d,
+                 ids_d,
+                 per_sample_weight_d,
+                 N,
+                 K,
+                 D,
+                 S,
+                 padding_idx_,
+                 mode_enum);
     }
   }
 
@@ -121,7 +141,7 @@
   const DenseTensor &input_;
   const DenseTensor &weight_;
   const DenseTensor &per_sample_weight_;
-  const std::string& mode_;
+  const std::string &mode_;
   const int64_t padding_idx_;
   DenseTensor *out_;
 };
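The forward kernel above pins down the reduction semantics: "sum" accumulates weighted rows, "mean" divides by the number of non-padding entries and returns zeros for an all-padding bag, and "max" tracks a running maximum. A hypothetical NumPy mirror of one bag's reduction follows (not the CUDA code; it also reproduces the kernel's zero initialization of max_d):

import numpy as np

def reduce_bag(ids, table, psw, mode, padding_idx=None):
    # ids: (S,) indices of one bag; table: (N, D); psw: (S,) per-sample weights
    S, D = len(ids), table.shape[1]
    total = np.zeros(D, dtype=table.dtype)
    max_d = np.zeros(D, dtype=table.dtype)  # same zero start as the kernel's max_d
    padding_count = 0
    for j, idx in enumerate(ids):
        if padding_idx is not None and idx == padding_idx:
            padding_count += 1  # padded rows contribute nothing
            continue
        total += table[idx] * psw[j]
        max_d = np.maximum(max_d, table[idx])
    if mode == "sum":
        return total
    if mode == "mean":
        # all-padding bag yields zeros; otherwise divide by the non-padding count
        return total if padding_count == S else total / (S - padding_count)
    return max_d  # mode == "max"
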
diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index b93c865f66723c..c5cbaf9b429d85 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -254,7 +254,15 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
         return tmp
 
 
-def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode="sum", sparse=False, name=None):
+def embedding_bag(
+    input,
+    weight,
+    per_sample_weight=None,
+    padding_idx=None,
+    mode="sum",
+    sparse=False,
+    name=None,
+):
     """
     Used to calculate the sum ,mean, or max of the specified bag in the embeddings vector by : attr:'input'.
     Each bag contains several row indexes of embeddings.
@@ -289,9 +297,9 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
             such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
             In these cases, sparse must be False. Default: False.
         padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
-            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
-            to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
-            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
+            If :math:`padding_idx < 0`, the :math:`padding_idx` will automatically be converted
+            to :math:`weight.shape[0] + padding_idx`. It will output all-zero padding data whenever lookup
+            encounters :math:`padding_idx` in id. And the padding data will not be updated while training.
             If set None, it makes no effect to output. Default: None.
         mode(str): Specifies the way to reduce the bag. "sum" computes the weighted sum, taking weight into consideration. "mean" computes the
             average of the values in the bag, "max" computes the max value over each bag. Default: "mean"
@@ -308,7 +316,7 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
            >>> input = np.random.randint(low=0, high=10, size = (2,6)).astype(np.int64)
            >>> input = paddle.to_tensor(input, stop_gradient = False)
            >>> per_sample_weight = np.random.random((2,6)).astype(np.float32)
-           >>> per_sample_weight = paddle.to_tensor(per_sample_weight, stop_gradient = False) 
+           >>> per_sample_weight = paddle.to_tensor(per_sample_weight, stop_gradient = False)
            >>> weight = np.random.random((10,3)).astype(np.float32)
            >>> weight = paddle.to_tensor(weight, stop_gradient = False)
            >>> sum = nn.functional.embedding_bag(input, weight, per_sample_weight, mode='sum')
@@ -330,7 +338,9 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
         )
 
     if in_dynamic_or_pir_mode():
-        return _C_ops.embedding_bag(input, weight, per_sample_weight, padding_idx, mode, sparse, name)
+        return _C_ops.embedding_bag(
+            input, weight, per_sample_weight, padding_idx, mode, sparse, name
+        )
     else:
         helper = LayerHelper('embedding_bag', **locals())
         dtype = helper.input_dtype(input_param_name='weight')
@@ -360,7 +370,7 @@ def embedding_bag(input, weight, per_sample_weight=None, padding_idx=None, mode=
                 'is_distributed': is_distributed,
                 'remote_prefetch': remote_prefetch,
                 'padding_idx': padding_idx,
-                'mode': mode
+                'mode': mode,
             },
         )
         return tmp
diff --git a/test/legacy_test/test_embeddingbag_op.py b/test/legacy_test/test_embeddingbag_op.py
index 6c6519c01e60b5..80ad296faae3d2 100644
--- a/test/legacy_test/test_embeddingbag_op.py
+++ b/test/legacy_test/test_embeddingbag_op.py
@@ -39,7 +39,9 @@ def manual_embeddingbag(input, params, weights=None, mode="sum"):
 def get_input(rows=5, cols=3, num_embeddings=10):
     a = np.random.choice(np.arange(num_embeddings), size=cols, replace=False)
     for _ in range(rows - 1):
-        b = np.random.choice(np.arange(num_embeddings), size=cols, replace=False)
+        b = np.random.choice(
+            np.arange(num_embeddings), size=cols, replace=False
+        )
         a = np.vstack((a, b))
     return a
 
@@ -53,11 +55,15 @@ def setUp(self):
         self.python_api = paddle.nn.functional.embedding_bag
         weight = np.random.random((20, 64)).astype(self.dtype)
         input = get_input(10, 20, weight.shape[0])
-        per_sample_weight = np.random.randint(low=0, high=10, size=input.shape).astype(
-            np.float64
-        )
-
-        self.inputs = {'input': input, 'weight': weight, 'per_sample_weight': per_sample_weight}
+        per_sample_weight = np.random.randint(
+            low=0, high=10, size=input.shape
+        ).astype(np.float64)
+
+        self.inputs = {
+            'input': input,
+            'weight': weight,
+            'per_sample_weight': per_sample_weight,
+        }
         np_out = manual_embeddingbag(input, weight, per_sample_weight)
         self.outputs = {
             'out': np_out.reshape((input.shape[0], weight.shape[1]))
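The body of `manual_embeddingbag` lies outside the hunks shown, so only its signature is visible. For context, a NumPy reference consistent with how the test calls it (shapes and defaults inferred here, not the actual test helper) could look like:

import numpy as np

def manual_embeddingbag_sketch(input, params, weights=None, mode="sum"):
    # input: (rows, cols) indices; params: (num_embeddings, dim); weights: (rows, cols)
    if weights is None:
        weights = np.ones(input.shape, dtype=params.dtype)
    gathered = params[input] * weights[..., None]  # (rows, cols, dim)
    if mode == "sum":
        return gathered.sum(axis=1)
    if mode == "mean":
        return gathered.mean(axis=1)
    return gathered.max(axis=1)  # mode == "max"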