Skip to content

Commit

Permalink
#14379: Support pad_value for ttnn.from_torch (#14380)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrshaw01 committed Nov 9, 2024
1 parent 37480d8 commit b75f637
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 33 deletions.
10 changes: 3 additions & 7 deletions tests/ttnn/unit_tests/operations/test_moreh_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,14 @@

import ttnn
from models.utility_functions import comp_allclose_and_pcc, is_grayskull

from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
compute_kernel_options,
create_ttnn_tilized_tensor,
get_compute_kernel_options,
)


def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)


def get_tensors(
input_shape,
mat2_shape,
Expand Down
12 changes: 4 additions & 8 deletions tests/ttnn/unit_tests/operations/test_moreh_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,17 @@

import ttnn
from models.utility_functions import comp_allclose

from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
TILE_HEIGHT,
TILE_WIDTH,
check_dim,
compute_kernel_ids,
compute_kernel_options,
create_ttnn_tilized_tensor,
get_compute_kernel_options,
)


def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)


def run_moreh_mean(
input_shape_dim,
device,
Expand Down
14 changes: 5 additions & 9 deletions tests/ttnn/unit_tests/operations/test_moreh_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,20 @@

import pytest
import torch
from loguru import logger

import ttnn
from models.utility_functions import comp_allclose, is_wormhole_b0
from loguru import logger

from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
check_dim,
compute_kernel_ids,
compute_kernel_options,
compute_output_shape,
check_dim,
create_ttnn_tilized_tensor,
get_compute_kernel_options,
)


def create_ttnn_tilized_tensor(torch_tensor, device, dtype):
return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT)


def make_torch_tensors(input_shape, dim, keepdim=False, *, dtype=torch.float32):
"""
Creates random tensors for input and gradient output based on the input shape and dimension.
Expand Down
11 changes: 8 additions & 3 deletions tests/ttnn/unit_tests/operations/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch
from models.utility_functions import is_wormhole_b0
import copy

import pytest
import torch

import ttnn
from models.utility_functions import is_wormhole_b0

TILE_HEIGHT = 32
TILE_WIDTH = 32
Expand Down Expand Up @@ -197,3 +198,7 @@ def get_ttnn_torch_dtype(ttnn_dtype: ttnn.DataType) -> torch.dtype:
ttnn.int32: torch.int32,
}
return dtype_map.get(ttnn_dtype, None)


def create_ttnn_tilized_tensor(torch_tensor, device, dtype, pad_value=float("nan")):
return ttnn.from_torch(torch_tensor, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, pad_value=pad_value)
19 changes: 13 additions & 6 deletions ttnn/ttnn/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

import math
import pathlib
from typing import Union, Tuple, Optional, Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Tuple, Union

from loguru import logger
import torch

import ttnn.decorators
from loguru import logger

import ttnn
import ttnn.decorators


def _golden_function(input_tensor: ttnn.Tensor, slices):
Expand Down Expand Up @@ -154,6 +153,7 @@ def from_torch(
dtype: Optional[ttnn.DataType] = None,
*,
tile: Optional[ttnn.Tile] = None,
pad_value: Optional[float] = None,
layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT,
device: Optional[ttnn.Device] = None,
memory_config: Optional[ttnn.MemoryConfig] = None,
Expand All @@ -171,6 +171,7 @@ def from_torch(
Keyword Args:
tile (ttnn.Tile, optional): the desired tiling configuration for the tensor. Defaults to `None`.
pad_value (float, optional): the desired padding value for tiling. Only used if `layout` is `TILE_LAYOUT`. Defaults to `None`.
layout (ttnn.Layout, optional): the desired `ttnn` layout. Defaults to `ttnn.ROW_MAJOR_LAYOUT`.
device (ttnn.Device, optional): the desired `ttnn` device. Defaults to `None`.
memory_config (ttnn.MemoryConfig, optional): The desired `ttnn` memory configuration. Defaults to `None`.
Expand All @@ -193,14 +194,18 @@ def from_torch(
if layout != ttnn.TILE_LAYOUT:
raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!")
# Tilize tensor
tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile)
tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value)
shape_with_padding = tensor.shape
tensor = tensor.reshape(tensor.shape.with_tile_padding())
tensor = ttnn.to_torch(tensor)

if memory_config is not None:
if device is None:
raise RuntimeError("device must be specified when memory_config is specified")
raise RuntimeError("ttnn.from_torch: device must be specified when memory_config is specified")

if pad_value is not None:
if layout != ttnn.TILE_LAYOUT:
raise RuntimeError("ttnn.from_torch: layout must be TILE_LAYOUT when pad_value is specified")

if mesh_mapper:
shards = mesh_mapper.map(tensor)
Expand All @@ -215,6 +220,8 @@ def from_torch(
tensor = ttnn.Tensor(tensor, dtype)

if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b):
if pad_value is not None:
tensor = tensor.pad_to_tile(pad_value)
tensor = ttnn.to_layout(tensor, layout, device=device)

if device is not None:
Expand Down

0 comments on commit b75f637

Please sign in to comment.