Skip to content

Commit

Permalink
Add prelu support for blackhole
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Nov 21, 2024
1 parent 60ad333 commit 80d07e3
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 6 deletions.
1 change: 0 additions & 1 deletion tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from functools import partial

import torch
import random
import ttnn

from tests.sweep_framework.sweep_utils.utils import gen_shapes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import (
data_gen_with_range,
data_gen_with_range_int,
data_gen_with_val,
compare_pcc,
compare_equal,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_prelu.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_prelu(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat init = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
a = a * init;
}
v_endif;
dst_reg[0] = a;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_prelu.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::prelu, APPROXIMATE>();
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_prelu(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_prelu<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,6 @@ enum SfpuType {
ceil,
unused,
cumsum,
fill
fill,
prelu,
};
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

Expand All @@ -23,7 +23,8 @@ inline void calculate_prelu(uint value) {
c_value.u = value;
vFloat init = c_value.f;

for (int d = 0; d < 8; d++)
#pragma GCC unroll 8
for (int d = 0; d < ITERATIONS; d++)
{
vFloat a = dst_reg[0];
v_if(a < 0.0f) {
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/include/compute_kernel_api/eltwise_unary/prelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ckernel {
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | param0 | The value the output is if the input is greater than 0 | uint32_t | | True |
* | param0 | Constant value that is being multiplied if the input is lesser than 0 | uint32_t | | True |
*/
ALWI void prelu_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_prelu<APPROX>(idst, param0)));
Expand Down

0 comments on commit 80d07e3

Please sign in to comment.