Skip to content

Commit

Permalink
#14782: Restructure doc for supported params
Browse files Browse the repository at this point in the history
  • Loading branch information
mcw-anasuya committed Nov 7, 2024
1 parent 04bf687 commit d6a01f8
Showing 1 changed file with 88 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ void bind_unary_backward_rsqrt(
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (uint8, optional): command queue id. Defaults to `0`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{3}
Expand Down Expand Up @@ -296,7 +299,14 @@ void bind_unary_backward_op_overload_abs(
}

template <typename unary_backward_operation_t>
void bind_unary_backward_float(py::module& module, const unary_backward_operation_t& operation, const std::string& description, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& supported_dtypes = "") {
void bind_unary_backward_float(
py::module& module,
const unary_backward_operation_t& operation,
const std::string& description,
const std::string& parameter_name_a,
const std::string& parameter_a_doc,
const std::string& supported_dtype = "BFLOAT16",
const std::string& note="") {
auto doc = fmt::format(
R"doc(
{2}
Expand All @@ -313,7 +323,19 @@ void bind_unary_backward_float(py::module& module, const unary_backward_operatio
List of ttnn.Tensor: the output tensor.
Note:
{5}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {5}
- TILE
- 2, 3, 4
{6}
Example:
Expand All @@ -326,7 +348,8 @@ void bind_unary_backward_float(py::module& module, const unary_backward_operatio
description,
parameter_name_a,
parameter_a_doc,
supported_dtypes);
supported_dtype,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -490,7 +513,8 @@ void bind_unary_backward_optional_float_params_with_default(
const std::string& parameter_b_doc,
std::optional<float> parameter_b_value,
const std::string& description,
const std::string& suported_dtype = "") {
const std::string& suported_dtype = "BFLOAT16",
const std::string& note = "") {
auto doc = fmt::format(
R"doc(
{8}
Expand All @@ -508,7 +532,20 @@ void bind_unary_backward_optional_float_params_with_default(
List of ttnn.Tensor: the output tensor.
Note:
{9}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {9}
- TILE
- 2, 3, 4
{10}
Example:
Expand All @@ -525,7 +562,8 @@ void bind_unary_backward_optional_float_params_with_default(
parameter_b_doc,
parameter_b_value,
description,
suported_dtype);
suported_dtype,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -795,7 +833,7 @@ void bind_unary_backward_shape(

template <typename unary_backward_operation_t>
void bind_unary_backward_optional(
py::module& module, const unary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") {
py::module& module, const unary_backward_operation_t& operation, const std::string_view description, const std::string& supported_dtype = "BFLOAT16", const std::string& layout = "TILE", const std::string& note="") {
auto doc = fmt::format(
R"doc(
{2}
Expand All @@ -813,7 +851,19 @@ void bind_unary_backward_optional(
List of ttnn.Tensor: the output tensor.
Note:
{3}
Supported dtypes, layouts, and ranks:
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - {3}
- {4}
- 2, 3, 4
{5}
Example:
Expand All @@ -824,7 +874,9 @@ void bind_unary_backward_optional(
operation.base_name(),
operation.python_fully_qualified_name(),
description,
supported_dtype);
supported_dtype,
layout,
note);

bind_registered_operation(
module,
Expand Down Expand Up @@ -922,11 +974,15 @@ void bind_unary_backward_prod_bw(py::module& module, const unary_backward_operat
Note:
Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 4 |
+----------------------------+---------------------------------+-------------------+
.. list-table::
:header-rows: 1
* - Dtypes
- Layouts
- Ranks
* - BFLOAT16
- TILE
- 4
Example:
Expand Down Expand Up @@ -1078,6 +1134,9 @@ void bind_unary_backward_gelu(
output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
queue_id (uint8, optional): command queue id. Defaults to `0`.
Returns:
List of ttnn.Tensor: the output tensor.
Note:
{6}
Expand Down Expand Up @@ -1130,16 +1189,9 @@ void py_module(py::module& module) {
"max",
"Maximum value",
std::nullopt,
R"doc(Performs backward operations for clamp value on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be `None`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for clamp value on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc",
R"doc(Only one of `min` or `max` value can be `None`.)doc");

detail::bind_unary_backward_optional_float_params_with_default(
module,
Expand All @@ -1150,16 +1202,9 @@ void py_module(py::module& module) {
"max",
"Maximum value",
std::nullopt,
R"doc(Performs backward operations for clip on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`. Only one of 'min' or 'max' value can be None.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for clip on :attr:`input_tensor`, :attr:`min`, :attr:`max` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc",
R"doc(Only one of `min` or `max` value can be `None`.)doc");

detail::bind_unary_backward_two_float_with_default(
module,
Expand Down Expand Up @@ -1390,43 +1435,17 @@ void py_module(py::module& module) {
module,
ttnn::exp_bw,
R"doc(Performs backward operations for exponential function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward_optional(
module,
ttnn::tanh_bw,
R"doc(Performs backward operations for hyperbolic tangent (tanh) function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for hyperbolic tangent (tanh) function on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward_optional(
module,
ttnn::sqrt_bw,
R"doc(Performs backward operations for square-root on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for square-root on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc");

detail::bind_unary_backward(
module,
Expand Down Expand Up @@ -1463,17 +1482,8 @@ void py_module(py::module& module) {
detail::bind_unary_backward_optional(
module,
ttnn::fill_bw,
R"doc(Performs backward operations for fill on :attr:`input_tensor` with given :attr:`grad_tensor`.
Returns an tensor like :attr:`grad_tensor` with sum of tensor values.)doc",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE, ROW MAJOR | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for fill on :attr:`input_tensor` with given :attr:`grad_tensor`.)doc",
R"doc(BFLOAT16)doc", R"doc(TILE, ROW MAJOR)doc");

detail::bind_unary_backward(
module,
Expand Down Expand Up @@ -1734,15 +1744,7 @@ void py_module(py::module& module) {
ttnn::rpow_bw,
R"doc(Performs backward operations for rpow on :attr:`input_tensor`, :attr:`exponent` with given :attr:`grad_tensor`.)doc",
"exponent","Exponent value",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward(
module,
Expand Down Expand Up @@ -2057,15 +2059,7 @@ void py_module(py::module& module) {
ttnn::div_no_nan_bw,
R"doc(Performs backward operations for div_no_nan on :attr:`input_tensor`, :attr:`scalar` with given :attr:`grad_tensor`.)doc",
"scalar","Denominator value",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(BFLOAT16, BFLOAT8_B)doc");

detail::bind_unary_backward_op(
module,
Expand Down Expand Up @@ -2168,17 +2162,8 @@ void py_module(py::module& module) {
detail::bind_unary_backward_float(
module,
ttnn::polygamma_bw,
R"doc(Performs backward operations for polygamma on :attr:`input_tensor` or attr:`input_tensor_a`, attr:`scalar` with given :attr:`grad_tensor`.)doc",
"n", "Order of polygamma function",
R"doc(Supported dtypes, layouts, and ranks:
+----------------------------+---------------------------------+-------------------+
| Dtypes | Layouts | Ranks |
+----------------------------+---------------------------------+-------------------+
| BFLOAT16 | TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+
)doc");
R"doc(Performs backward operations for polygamma on :attr:`input_tensor` or :attr:`input_tensor_a`, :attr:`scalar` with given :attr:`grad_tensor`.)doc",
"n", "Order of polygamma function");
}

} // namespace unary_backward
Expand Down

0 comments on commit d6a01f8

Please sign in to comment.