Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.4】为 Paddle 新增 masked_scatter API #6405

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/paddle/Overview_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ tensor 数学操作原位(inplace)版本
" :ref:`paddle.hypot_ <cn_api_paddle_hypot_>` ", "Inplace 版本的 hypot API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.multigammaln_ <cn_api_paddle_multigammaln_>` ", "Inplace 版本的 multigammaln API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.masked_fill_ <cn_api_paddle_masked_fill_>` ", "Inplace 版本的 masked_fill API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.masked_scatter_ <cn_api_paddle_masked_scatter_>` ", "Inplace 版本的 masked_scatter API,对输入 x 采用 Inplace 策略"
" :ref:`paddle.index_fill_ <cn_api_paddle_index_fill_>` ", "Inplace 版本的 index_fill API,对输入 x 采用 Inplace 策略"

.. _tensor_logic:
Expand Down Expand Up @@ -401,6 +402,7 @@ tensor 元素操作相关(如:转置,reshape 等)
" :ref:`paddle.view_as <cn_api_paddle_view_as>` ", "使用 other 的 shape,返回 x 的一个 view Tensor"
" :ref:`paddle.unfold <cn_api_paddle_unfold>` ", "返回 x 的一个 view Tensor。以滑动窗口式提取 x 的值"
" :ref:`paddle.masked_fill <cn_api_paddle_masked_fill>` ", "根据 mask 信息,将 value 中的值填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.masked_scatter <cn_api_paddle_masked_scatter>` ", "根据 mask 信息,将 value 中的值逐个填充到 x 中 mask 对应为 True 的位置。"
" :ref:`paddle.diagonal_scatter <cn_api_paddle_diagonal_scatter>` ", "根据给定的轴 axis 和偏移量 offset,将张量 y 的值填充到张量 x 中"
" :ref:`paddle.index_fill <cn_api_paddle_index_fill>` ", "沿着指定轴 axis 将 index 中指定位置的 x 的值填充为 value"
" :ref:`paddle.column_stack <cn_api_paddle_column_stack>` ", "沿水平轴堆叠输入 x 中的所有张量。"
Expand Down
13 changes: 13 additions & 0 deletions docs/api/paddle/Tensor_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3121,6 +3121,19 @@ masked_fill_(x, mask, value, name=None)

Inplace 版本的 :ref:`cn_api_paddle_masked_fill` API,对输入 `x` 采用 Inplace 策略。

masked_scatter(x, mask, value, name=None)
:::::::::
根据 mask 信息,将 value 中的值逐个填充到 x 中 mask 对应为 True 的位置。

返回一个根据 mask 将对应位置填充为 value 中元素的 Tensor。

请参考 :ref:`cn_api_paddle_masked_scatter`

masked_scatter_(x, mask, value, name=None)
:::::::::

Inplace 版本的 :ref:`cn_api_paddle_masked_scatter` API,对输入 `x` 采用 Inplace 策略。

atleast_1d(name=None)
:::::::::
将输入转换为张量并返回至少为 ``1`` 维的视图。 ``1`` 维或更高维的输入会被保留。
Expand Down
11 changes: 11 additions & 0 deletions docs/api/paddle/masked_scatter__cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _cn_api_paddle_masked_scatter_:

masked_scatter\_
-------------------------------

.. py:function:: paddle.masked_scatter_(x, mask, value, name=None)
Inplace 版本的 :ref:`cn_api_paddle_masked_scatter` API,对输入 x 采用 Inplace 策略。

更多关于 inplace 操作的介绍请参考 `3.1.3 原位(Inplace)操作和非原位操作的区别`_ 了解详情。

.. _3.1.3 原位(Inplace)操作和非原位操作的区别: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/beginner/tensor_cn.html#id3
28 changes: 28 additions & 0 deletions docs/api/paddle/masked_scatter_cn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
.. _cn_api_paddle_masked_scatter:

masked_scatter
-------------------------------

.. py:function:: paddle.masked_scatter(x, mask, value, name=None)



返回一个 N-D 的 Tensor,Tensor 的值是根据 ``mask`` 信息,将 ``value`` 中的值逐个填充到 ``x`` 中 ``mask`` 对应为 ``True`` 的位置,``mask`` 的数据类型是 bool。

参数
::::::::::::

- **x** (Tensor) - 输入 Tensor,数据类型为 float,double,int,int64_t,float16 或者 bfloat16。
- **mask** (Tensor) - 布尔张量,表示要填充的位置。mask 的数据类型必须为 bool。
- **value** (Tensor) - 用于填充目标张量的值,数据类型为 float,double,int,int64_t,float16 或者 bfloat16。
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。

返回
::::::::::::
返回一个根据 ``mask`` 将对应位置逐个填充 ``value`` 中的 Tensor,数据类型与 ``x`` 相同。


代码示例
::::::::::::

COPY-FROM: paddle.masked_scatter
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## [ 参数完全一致 ] torch.Tensor.masked_scatter

### [torch.Tensor.masked_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter.html?highlight=masked_scatter#torch.Tensor.masked_scatter)

```python
torch.Tensor.masked_scatter(mask, value)
```

### [paddle.Tensor.masked_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#masked-scatter-mask-value-name-non)

```python
paddle.Tensor.masked_scatter(mask, value, name=None)
```

两者功能一致,参数完全一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| mask | mask | 布尔张量,表示要填充的位置 |
| value | value | 用于填充目标张量的值 |
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## [ 参数完全一致 ] torch.Tensor.masked_scatter_

### [torch.Tensor.masked_scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html?highlight=masked_scatter#torch.Tensor.masked_scatter_)

```python
torch.Tensor.masked_scatter_(mask, value)
```

### [paddle.Tensor.masked_scatter_]()

```python
paddle.Tensor.masked_scatter_(mask, value, name=None)
```

两者功能一致,参数完全一致,具体如下:

### 参数映射

| PyTorch | PaddlePaddle | 备注 |
|---------|--------------| -------------------------------------------------- |
| mask | mask | 布尔张量,表示要填充的位置 |
| value | value | 用于填充目标张量的值 |
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,8 @@
| 304 | [torch.Tensor.resize_](https://pytorch.org/docs/stable/generated/torch.Tensor.resize_.html?highlight=resize#torch.Tensor.resize_) | | 功能缺失 |
| 305 | [torch.Tensor.masked_fill_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html?highlight=resize#torch.Tensor.masked_fill_) | [paddle.Tensor.masked_fill_](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |
| 306 | [torch.Tensor.tensor_split](https://pytorch.org/docs/stable/generated/torch.Tensor.tensor_split.html) | [paddle.Tensor.tensor_split](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#tensor_split-indices_or_sections-axis-0-name-none) | 功能完全一致,仅参数名不一致 [差异对比](https://github.com/PaddlePaddle/docs/tree/develop/docs/guides/model_convert/convert_from_pytorch/api_difference/Tensor/torch.Tensor.tensor_split.md) |
| 307 | [torch.Tensor.masked_scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter.html?highlight=resize#torch.Tensor.masked_scatter) | [paddle.Tensor.masked_scatter](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |
| 308 | [torch.Tensor.masked_scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_scatter_.html?highlight=resize#torch.Tensor.masked_scatter_) | [paddle.Tensor.masked_scatter_](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Tensor_cn.html#id25) | 功能完全一致 |


| 序号 | PyTorch API | PaddlePaddle API | 备注 |
Expand Down