Skip to content

Commit

Permalink
#14862: fp32 support in unary
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Nov 8, 2024
1 parent bd3b2d7 commit e11ae6b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
24 changes: 24 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_fp32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn

# Show enough decimal places that tiny fp32 magnitudes are distinguishable
# from zero in the printed output below.
torch.set_printoptions(precision=10)

with ttnn.manage_device(device_id=0) as device:
    # 1e-5 is representable in fp32 but would be coarsely rounded in bfloat16;
    # negation must preserve it exactly if the op truly runs at fp32 precision.
    input_torch = torch.tensor([[0.00001]], dtype=torch.float32)
    expected_torch = -input_torch

    # Round-trip through the device: host fp32 -> device tile -> ttnn.neg.
    input_tt = ttnn.from_torch(input_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    result_tt = ttnn.neg(input_tt)

    tt_out = ttnn.to_torch(result_tt)
    print(input_torch, ttnn.to_torch(input_tt))
    print(expected_torch, ttnn.to_torch(result_tt))

    # atol=1e-10 is far below the 1e-5 operand, so any bfloat16 fallback
    # (which would perturb the value) fails this comparison.
    status = torch.allclose(expected_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
    print("pass: ", status)

    assert status
5 changes: 2 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ inline Tensor unary_impl(
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt) {
DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[1]) : input_tensor.get_dtype();
bool preserve_fp32_precision = (op_chain[0].op_type == UnaryOpType::TYPECAST) and (input_tensor.get_dtype() == DataType::FLOAT32);
bool preserve_fp32_precision = input_tensor.get_dtype() == DataType::FLOAT32;
bool fp32_dest_acc_en = preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to
// DST directly, fp32 is converted to fp16b
input_tensor.get_dtype() == DataType::INT32;

auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config());
return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor);
Expand Down

0 comments on commit e11ae6b

Please sign in to comment.