Skip to content

Commit

Permalink
tenstorrent#11086: Update supported params for all ops in unary backw…
Browse files Browse the repository at this point in the history
…ard doc (tenstorrent#14376)

* tenstorrent#11086: Update bind_unary_backward_op

* tenstorrent#11086: Update remaining ops

* tenstorrent#14782: Restructure doc for supported params

---------

Co-authored-by: VirdhatchaniKN <virdhatchani.narayanamoorthy@multicorewareinc.com>
  • Loading branch information
2 people authored and Christopher Taylor committed Nov 12, 2024
1 parent a2b9e1a commit e4c3dcc
Showing 1 changed file with 114 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ void bind_unary_backward_rsqrt(
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (uint8, optional): command queue id. Defaults to `0`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{3}
Expand Down Expand Up @@ -296,7 +299,14 @@ void bind_unary_backward_op_overload_abs(
}

template <typename unary_backward_operation_t>
void bind_unary_backward_float(py::module& module, const unary_backward_operation_t& operation, const std::string& description, const std::string& parameter_name_a, const std::string& parameter_a_doc) {
void bind_unary_backward_float(
py::module& module,
const unary_backward_operation_t& operation,
const std::string& description,
const std::string& parameter_name_a,
const std::string& parameter_a_doc,
const std::string& supported_dtype = "BFLOAT16",
const std::string& note="") {
auto doc = fmt::format(
R"doc(
{2}
Expand All @@ -312,6 +322,21 @@ void bind_unary_backward_float(py::module& module, const unary_backward_operatio
Returns:
List of ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {5}
- TILE
- 2, 3, 4
{6}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
Expand All @@ -322,7 +347,9 @@ void bind_unary_backward_float(py::module& module, const unary_backward_operatio
operation.python_fully_qualified_name(),
description,
parameter_name_a,
parameter_a_doc);
parameter_a_doc,
supported_dtype,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -485,7 +512,9 @@ void bind_unary_backward_optional_float_params_with_default(
const std::string& parameter_name_b,
const std::string& parameter_b_doc,
std::optional<float> parameter_b_value,
const std::string& description) {
const std::string& description,
const std::string& suported_dtype = "BFLOAT16",
const std::string& note = "") {
auto doc = fmt::format(
R"doc(
{8}
Expand All @@ -502,6 +531,22 @@ void bind_unary_backward_optional_float_params_with_default(
Returns:
List of ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {9}
- TILE
- 2, 3, 4
{10}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
Expand All @@ -516,7 +561,9 @@ void bind_unary_backward_optional_float_params_with_default(
parameter_name_b,
parameter_b_doc,
parameter_b_value,
description);
description,
suported_dtype,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -786,14 +833,14 @@ void bind_unary_backward_shape(

template <typename unary_backward_operation_t>
void bind_unary_backward_optional(
py::module& module, const unary_backward_operation_t& operation, const std::string_view description) {
py::module& module, const unary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16", const std::string& layout = "TILE", const std::string& note="") {
auto doc = fmt::format(
R"doc(
{2}
Args:
grad_tensor (ComplexTensor or ttnn.Tensor): the input gradient tensor.
input_tensor_a (ComplexTensor or ttnn.Tensor): the input tensor.
grad_tensor (ttnn.Tensor): the input gradient tensor.
input_tensor_a (ttnn.Tensor): the input tensor.
Keyword args:
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
Expand All @@ -803,6 +850,21 @@ void bind_unary_backward_optional(
Returns:
List of ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- {4}
- 2, 3, 4
{5}
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
Expand All @@ -811,7 +873,10 @@ void bind_unary_backward_optional(
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);
description,
supported_dtype,
layout,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -899,13 +964,26 @@ void bind_unary_backward_prod_bw(py::module& module, const unary_backward_operat
input_tensor_a (ttnn.Tensor): the input tensor.
Keyword args:
all_dimensions (bool, optional): perform prod backward along all dimensions ,ignores dim param . Defaults to `True`.
dim (int, optional): Dimension to perform prod backward. Defaults to `0`.
all_dimensions (bool, optional): perform prod backward along all dimensions, ignores dim param. Defaults to `True`.
dim (int, optional): dimension to perform prod backward. Defaults to `0`.
memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16
- TILE
- 4
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device=device)
Expand Down Expand Up @@ -985,7 +1063,7 @@ void bind_unary_backward_opt(py::module& module, const unary_backward_operation_

template <typename unary_backward_operation_t>
void bind_unary_backward(
py::module& module, const unary_backward_operation_t& operation, const std::string& description, const std::string& note = "") {
py::module& module, const unary_backward_operation_t& operation, const std::string& description, const std::string& supported_dtype = "") {
auto doc = fmt::format(
R"doc(
{2}
Expand All @@ -1012,7 +1090,7 @@ void bind_unary_backward(
operation.base_name(),
operation.python_fully_qualified_name(),
description,
note);
supported_dtype);

bind_registered_operation(
module,
Expand Down Expand Up @@ -1056,6 +1134,9 @@ void bind_unary_backward_gelu(
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (uint8, optional): command queue id. Defaults to `0`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{6}
Expand Down Expand Up @@ -1108,7 +1189,9 @@ void py_module(py::module& module) {
"max",
"Maximum value",
std::nullopt,
R"doc(Performs backward operations for clamp value on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be None.)doc");
R"doc(Performs backward operations for clamp value on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc",
R"doc(Only one of `min` or `max` value can be `None`.)doc");

detail::bind_unary_backward_optional_float_params_with_default(
module,
Expand All @@ -1119,7 +1202,9 @@ void py_module(py::module& module) {
"max",
"Maximum value",
std::nullopt,
R"doc(Performs backward operations for clip on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be None.)doc");
R"doc(Performs backward operations for clip on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc",
R"doc(Only one of `min` or `max` value can be `None`.)doc");

detail::bind_unary_backward_two_float_with_default(
module,
Expand Down Expand Up @@ -1349,12 +1434,13 @@ void py_module(py::module& module) {
detail::bind_unary_backward_optional(
module,
ttnn::exp_bw,
R"doc(Performs backward operations for exponential function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");
R"doc(Performs backward operations for exponential function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward_optional(
module,
ttnn::tanh_bw,
R"doc(Performs backward operations for Hyperbolic Tangent (Tanh) function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");
R"doc(Performs backward operations for hyperbolic tangent (tanh) function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward_optional(
module,
Expand Down Expand Up @@ -1396,8 +1482,8 @@ void py_module(py::module& module) {
detail::bind_unary_backward_optional(
module,
ttnn::fill_bw,
R"doc(Performs backward operations for fill on :attr:`input_tensor` with given :attr:`grad_tensor`.
Returns an tensor like :attr:`grad_tensor` with sum of tensor values.)doc");
R"doc(Performs backward operations for fill on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc", R"doc(TILE, ROW MAJOR)doc");

detail::bind_unary_backward(
module,
Expand Down Expand Up @@ -1464,7 +1550,7 @@ void py_module(py::module& module) {
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
Expand Down Expand Up @@ -1494,6 +1580,8 @@ void py_module(py::module& module) {
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE, ROW_MAJOR | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
| BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");

Expand All @@ -1520,7 +1608,7 @@ void py_module(py::module& module) {
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
Expand Down Expand Up @@ -1655,7 +1743,8 @@ void py_module(py::module& module) {
module,
ttnn::rpow_bw,
R"doc(Performs backward operations for rpow on :attr:`input_tensor`, :attr:`exponent` with given :attr:`grad_tensor`.)doc",
"exponent","Exponent value");
"exponent","Exponent value",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward(
module,
Expand Down Expand Up @@ -1969,7 +2058,8 @@ void py_module(py::module& module) {
module,
ttnn::div_no_nan_bw,
R"doc(Performs backward operations for div_no_nan on :attr:`input_tensor`, :attr:`scalar` with given :attr:`grad_tensor`.)doc",
"scalar","Denominator value");
"scalar","Denominator value",
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward_op(
module,
Expand Down Expand Up @@ -2022,7 +2112,7 @@ void py_module(py::module& module) {
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
Expand Down Expand Up @@ -2072,7 +2162,7 @@ void py_module(py::module& module) {
detail::bind_unary_backward_float(
module,
ttnn::polygamma_bw,
R"doc(Performs backward operations for polygamma on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`scalar` with given :attr:`grad_tensor`.)doc",
R"doc(Performs backward operations for polygamma on :attr:`input_tensor` or :attr:`input_tensor_a`, :attr:`scalar` with given :attr:`grad_tensor`.)doc",
"n", "Order of polygamma function");
}

Expand Down

0 comments on commit e4c3dcc

Please sign in to comment.