#14982: Update threshold logic (#15362)
### Ticket
#14982

### Problem description
The previous `threshold` implementation did not correctly handle elements of `input_tensor` that are equal to the `threshold` value, and the op's documentation did not list the supported data types and layouts.

### What's changed

- Updated the threshold logic to handle the case where `input_tensor` equals the `threshold` value (see the sketch after this list)
- Documented the supported data types, layout, and ranks
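
The intended semantics, including the `input_tensor == threshold` boundary case, in plain PyTorch (a minimal sketch with assumed sample values, not code from this repo):

```python
import torch

# Reference semantics: threshold(a, t, v) = v where a <= t, else a.
# The boundary matters: an element exactly equal to the threshold is replaced.
def threshold_ref(x: torch.Tensor, threshold: float, value: float) -> torch.Tensor:
    return torch.where(x <= threshold, torch.full_like(x, value), x)

x = torch.tensor([0.5, 1.0, 2.0], dtype=torch.bfloat16)
print(threshold_ref(x, threshold=1.0, value=10.0))
# tensor([10., 10., 2.], dtype=torch.bfloat16) -- the 1.0 element is replaced
```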

### Tests

- `pytest tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_threshold.py`:
  <img width="1101" alt="Screenshot 2024-11-22 at 6 09 03 PM" src="https://github.com/user-attachments/assets/80b5820f-8e4d-4a8b-9914-d3df21f72f90">
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_activation.py::test_threshold`:
  <img width="1087" alt="Screenshot 2024-11-22 at 6 09 39 PM" src="https://github.com/user-attachments/assets/61f30c51-5d78-49d2-af08-e7c764a330a0">
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_unary_composite_threshold_ttnn`:
  <img width="1088" alt="Screenshot 2024-11-22 at 6 10 21 PM" src="https://github.com/user-attachments/assets/a2ca063e-65fb-4a4e-a9cc-6348eb3c95c0">
- `python tests/ttnn/sweep_tests/run_sweeps.py --include threshold.py` - passed
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_threshold_example`:
  <img width="1092" alt="Screenshot 2024-11-22 at 6 11 40 PM" src="https://github.com/user-attachments/assets/56eafb45-8011-440c-8958-f212869b0634">

### Checklist
- [ ] Post commit CI passes

### Doc screenshot
<img width="963" alt="Screenshot 2024-11-22 at 11 06 48 PM"
src="https://github.com/user-attachments/assets/1be96ac7-dc1b-434a-ad1d-ebe95edd7b51">
VirdhatchaniKN authored Nov 24, 2024
1 parent 077f71e commit 388d56e
Showing 4 changed files with 31 additions and 20 deletions.
@@ -543,11 +543,9 @@ Tensor _selu(const Tensor& x, const float scale, const float alpha, const std::o
 }
 
 // threshold(a,t,v) = (a <= t)*v + (a > t)*a
-Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const std::optional<MemoryConfig>& output_mem_config) {
-    Tensor t0 = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config);
-    Tensor t1 = ttnn::multiply(ttnn::lez(t0), value, std::nullopt, output_mem_config);
-    Tensor t2 = ttnn::multiply(ttnn::gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config);
-    return ttnn::add(t1, t2, std::nullopt, output_mem_config);
+Tensor ExecuteUnaryCompositeThreshold::invoke(const Tensor& input_tensor, float threshold, float value, const std::optional<MemoryConfig>& output_mem_config) {
+    Tensor sub_result = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config);
+    return ttnn::where(ttnn::lez(sub_result), value, input_tensor, output_mem_config);
 }
 
 std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
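
For intuition, a small NumPy sketch (illustrative only, not repository code) checking that the old mask-multiply formulation and the new `where`-based one agree on finite inputs, including elements equal to the threshold:

```python
import numpy as np

x = np.array([-1.0, 1.5, 1.5, 3.0], dtype=np.float32)
t, v = 1.5, 10.0
diff = x - t

# Old formulation: lez(x - t) * v + gtz(x - t) * x
old = (diff <= 0).astype(np.float32) * v + (diff > 0).astype(np.float32) * x
# New formulation: where(lez(x - t), v, x)
new = np.where(diff <= 0, v, x)

assert np.array_equal(old, new)   # agrees on the x == t elements too
print(new)                        # [10. 10. 10.  3.]
```

The `where` form also replaces two multiplies and an add per element with a single select.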
@@ -37,7 +37,6 @@ enum class UnaryCompositeOpType {
     HARDSIGMOID,
     HARDTANH,
     SELU,
-    THRESHOLD,
     GLU,
     REGLU,
     GEGLU,
@@ -82,7 +81,6 @@ Tensor _hardswish(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, c
 Tensor _hardsigmoid(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _hardtanh(const Tensor&, float min = -1, float max = 1, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _selu(const Tensor&, float scale = 1.0507, float alpha = 1.67326, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>& );
 Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
 Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
 Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
@@ -267,13 +265,6 @@ struct OpHandler<UnaryCompositeOpType::SELU> {
     }
 };
 
-template <>
-struct OpHandler<UnaryCompositeOpType::THRESHOLD> {
-    static Tensor handle(const Tensor& t1, float threshold, float value, const std::optional<MemoryConfig>& mem_cfg ) {
-        return _threshold(t1, threshold, value, mem_cfg);
-    }
-};
-
 //glu (geglu, reglu, swiglu, glu) varinats are supported only for last dimension.
 template <>
 struct OpHandler<UnaryCompositeOpType::GLU> {
11 changes: 9 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -129,6 +129,14 @@ struct ExecuteUnaryCompositeClamp {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
+struct ExecuteUnaryCompositeThreshold {
+    static Tensor invoke(
+        const Tensor &input_tensor,
+        float threshold,
+        float value,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
 struct ExecuteUnaryCompositeClip {
     static Tensor invoke(
         const Tensor &input_tensor,
@@ -305,8 +313,7 @@ constexpr auto selu = ttnn::register_operation_with_auto_launch_op<
     operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::SELU>>();
 constexpr auto threshold = ttnn::register_operation_with_auto_launch_op<
     "ttnn::threshold",
-    operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::THRESHOLD>>();
-
+    operations::unary::ExecuteUnaryCompositeThreshold>();
 constexpr auto glu = ttnn::register_operation_with_auto_launch_op<
     "ttnn::glu",
     operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GLU>>();
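
With the registration switched to the dedicated struct, the op is still exposed as `ttnn.threshold`; a minimal Python usage sketch (device setup assumed, following common ttnn examples):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

tensor = ttnn.from_torch(
    torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
output = ttnn.threshold(tensor, 1.0, 10.0)  # elements <= 1.0 become 10.0
print(ttnn.to_torch(output))

ttnn.close_device(device)
```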
23 changes: 19 additions & 4 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1319,7 +1319,7 @@ void bind_unary_composite_int(py::module& module, const unary_operation_t& opera
 
 //OpHandler_two_float_with_default
 template <typename unary_operation_t>
-void bind_unary_composite_floats(
+void bind_unary_composite_threshold(
     py::module& module,
     const unary_operation_t& operation,
     const std::string& parameter_name_a,
@@ -1342,8 +1342,23 @@
         Returns:
             ttnn.Tensor: the output tensor.
 
         Note:
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16
+                 - TILE
+                 - 2, 3, 4
+
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> {2} = 1.0
+            >>> {4} = 10.0
+            >>> output = {1}(tensor, {2}, {4})
 )doc",
         operation.base_name(),
@@ -1975,11 +1990,11 @@
         ttnn::selu,
         "scale", "Scale value", 1.0507,
         "alpha", "Alpha value", 1.67326);
-    detail::bind_unary_composite_floats(
+    detail::bind_unary_composite_threshold(
         module,
         ttnn::threshold,
         "threshold", "Threshold value",
-        "value", "Value value",
+        "value", "Replacing value",
        R"doc(Performs threshold function on :attr:`input_tensor`, :attr:`threshold`, :attr:`value`.)doc");
     detail::bind_unary_composite_int_with_default(
         module,
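
A pytest-style sketch of the boundary case this PR targets (shape and values are assumptions, not copied from the repository's test files):

```python
import torch
import ttnn

def test_threshold_boundary(device):  # `device` fixture as in the ttnn unit tests
    threshold, value = 1.0, 10.0
    torch_input = torch.full((32, 32), threshold, dtype=torch.bfloat16)
    torch_expected = torch.nn.functional.threshold(torch_input, threshold, value)

    tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
    tt_output = ttnn.to_torch(ttnn.threshold(tt_input, threshold, value))

    # Every input element equals the threshold, so all must become `value`.
    assert torch.equal(tt_output, torch_expected)
```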
