Tests for linear op #965

Open · wants to merge 2 commits into base: main
271 changes: 271 additions & 0 deletions forge/test/operators/pytorch/nn/test_convtranspose2d.py
@@ -0,0 +1,271 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
from functools import reduce
import random
import pytest

from typing import List, Dict, Type, Optional, Any
from loguru import logger

import torch
import forge
import forge.op

from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AllCloseValueChecker

from test.operators.utils import InputSourceFlags, VerifyUtils, ValueRanges
from test.operators.utils import InputSource
from test.operators.utils import TestVector
from test.operators.utils import TestPlan
from test.operators.utils import FailingReasons
from test.operators.utils.compat import TestDevice
from test.operators.utils.compat import TestTensorsUtils
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon


class ModelFromAnotherOp(torch.nn.Module):

model_name = "model_op_src_from_another_op"

def __init__(self, operator, opname, shape, kwargs):
super(ModelFromAnotherOp, self).__init__()
self.testname = "ConvTranspose2d_pytorch_operator_" + opname + "_test_op_src_from_another_op"
self.operator = operator
self.opname = opname
self.shape = shape
self.kwargs = kwargs

self.ct1 = self.operator(**self.kwargs)

def forward(self, x: torch.Tensor):
# we use the Add operator to create an operand that serves as the input to the ConvTranspose2d operator
add = torch.add(x, x)
output = self.ct1(add)
return output


class ModelDirect(torch.nn.Module):

model_name = "model_op_src_from_host"

def __init__(self, operator, opname, shape, kwargs):
super(ModelDirect, self).__init__()
self.testname = "ConvTranspose2d_pytorch_operator_" + opname + "_test_op_src_from_host"
self.operator = operator
self.opname = opname
self.shape = shape
self.kwargs = kwargs

self.ct1 = self.operator(**self.kwargs)

def forward(self, x: torch.Tensor):
output = self.ct1(x)
return output


class ModelConstEvalPass(torch.nn.Module):

model_name = "model_op_src_const_eval_pass"

def __init__(self, operator, opname, shape, kwargs, dtype):
super(ModelConstEvalPass, self).__init__()
self.testname = "ConvTranspose2d_pytorch_operator_" + opname + "_test_op_src_const_eval_pass"
self.operator = operator
self.opname = opname
self.shape = shape
self.kwargs = kwargs

self.constant = torch.rand(self.shape, dtype=dtype)
self.ct1 = self.operator(**self.kwargs)

def forward(self, x: torch.Tensor):
v1 = self.ct1(self.constant)
# v2 = torch.add(x, x)
v2 = self.ct1(x)
# add consumes both the const-eval branch and the runtime input
add = torch.add(v1, v2)
return add
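
# Illustrative sketch (not exercised by the test plan): a standalone smoke test
# of ModelDirect wired to torch.nn.ConvTranspose2d. The input shape (1, 3, 8, 8)
# and the kwargs below are example values picked for this sketch, not values
# produced by TestParamsData.generate_kwargs.
def _sketch_model_direct_smoke_test():
    kwargs = {"in_channels": 3, "out_channels": 6, "kernel_size": 3}
    model = ModelDirect(
        operator=torch.nn.ConvTranspose2d,
        opname="ConvTranspose2d",
        shape=(1, 3, 8, 8),
        kwargs=kwargs,
    )
    output = model(torch.rand(1, 3, 8, 8))
    # with the default stride=1, padding=0, dilation=1 each spatial dim grows by kernel_size - 1
    assert output.shape == (1, 6, 10, 10)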


class TestVerification:

MODEL_TYPES = {
InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp,
InputSource.FROM_HOST: ModelDirect,
InputSource.FROM_DRAM_QUEUE: ModelDirect,
InputSource.CONST_EVAL_PASS: ModelConstEvalPass,
}

@classmethod
def verify(
cls,
test_device: TestDevice,
test_vector: TestVector,
input_params: List[Dict] = [],
number_of_operands: int = 1,
warm_reset: bool = False,
):
"""Common verification function for all tests"""

input_source_flag: Optional[InputSourceFlags] = None
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch.nn, test_vector.operator)

kwargs = test_vector.kwargs if test_vector.kwargs else {}

model_type = cls.MODEL_TYPES[test_vector.input_source]
if test_vector.input_source == InputSource.CONST_EVAL_PASS:
pytorch_model = model_type(
operator=operator,
opname=test_vector.operator,
shape=test_vector.input_shape,
kwargs=kwargs,
dtype=TestTensorsUtils.get_dtype_for_df(test_vector.dev_data_format),
)
else:
pytorch_model = model_type(
operator=operator,
opname=test_vector.operator,
shape=test_vector.input_shape,
kwargs=kwargs,
)

input_shapes = tuple([test_vector.input_shape for _ in range(number_of_operands)])
logger.trace(f"***input_shapes: {input_shapes}")

VerifyUtils.verify(
model=pytorch_model,
test_device=test_device,
input_shapes=input_shapes,
input_params=input_params,
input_source_flag=input_source_flag,
dev_data_format=test_vector.dev_data_format,
math_fidelity=test_vector.math_fidelity,
pcc=test_vector.pcc,
warm_reset=warm_reset,
deprecated_verification=False,
verify_config=VerifyConfig(value_checker=AllCloseValueChecker(rtol=1e-2, atol=1e-2)),
value_range=ValueRanges.SMALL,
)


class TestParamsData:

__test__ = False # Avoid collecting TestParamsData as a pytest test

test_plan: TestPlan = None

@classmethod
def generate_kwargs(cls, test_vector: TestVector):
kwarg_list = []
rng = random.Random(sum(test_vector.input_shape))
# if len(test_vector.input_shape) == 4:
# N = test_vector.input_shape[-4]
N = test_vector.input_shape[0]
C_in = test_vector.input_shape[-3]
H_in = test_vector.input_shape[-2]
W_in = test_vector.input_shape[-1]

in_channels = C_in
out_channels = rng.randint(1, C_in + 100)  # it can be less than, equal to, or greater than in_channels

dilation = rng.randint(1, 3)

# Two cases for kernel size
k_maxh = max(3, max(H_in, 3) // dilation + 1)
k_maxw = max(3, max(W_in, 3) // dilation + 1)
kernel_size_option1 = rng.randint(3, k_maxh)
kernel_size_option2 = rng.randint(3, k_maxw)
# 1. kernel size as a single integer
kernel_size_int = rng.choice([kernel_size_option1, kernel_size_option2])
# make it an odd number
kernel_size_int = kernel_size_int if kernel_size_int % 2 != 0 else kernel_size_int + 1
# 2. kernel size as a (height, width) tuple
kernel_size_tuple = (kernel_size_option1, kernel_size_option2)
# randomly pick one of the two cases
kernel_size = rng.choice([kernel_size_int, kernel_size_tuple])
# assert that kernel value will fit in the input shape
# if isinstance(kernel_size, tuple):
# assert dilation * (kernel_size[0] - 1) < H_in, "Invalid kernel height!"
# assert dilation * (kernel_size[1] - 1) < W_in, "Invalid kernel width!"
# else:
# assert dilation * (kernel_size - 1) < H_in, "Invalid height"
# assert dilation * (kernel_size - 1) < W_in, "Invalid width"

kwarg_list.append(
{"in_channels": in_channels, "out_channels": out_channels, "kernel_size": kernel_size, "dilation": dilation}
)
return kwarg_list
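
# Illustrative sketch (not exercised by the test plan): the ConvTranspose2d
# output height for kwargs like the ones generated above follows
#   H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_h - 1) + output_padding + 1
# (analogously for the width). With the defaults stride=1, padding=0 and
# output_padding=0 used here the output only grows, e.g.
#   torch.nn.ConvTranspose2d(1, 1, kernel_size=5, dilation=2)(torch.rand(1, 1, 8, 8)).shape
#   -> torch.Size([1, 1, 16, 16])
# The helper below is an example calculation assuming those default values; it is
# not part of the test plan.
def _sketch_convtranspose2d_output_dim(dim_in: int, kernel: int, dilation: int) -> int:
    stride, padding, output_padding = 1, 0, 0
    return (dim_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1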


class TestCollectionData:

__test__ = False # Avoid collecting TestCollectionData as a pytest test

all = TestCollection(
operators=[
"ConvTranspose2d", # 00
],
input_sources=TestCollectionCommon.all.input_sources,
# only 4D input tensors are supported
input_shapes=[input_shape for input_shape in TestCollectionCommon.all.input_shapes if len(input_shape) == 4],
dev_data_formats=TestCollectionCommon.all.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
)

single = TestCollection(
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
math_fidelities=TestCollectionCommon.single.math_fidelities,
)


TestParamsData.test_plan = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
test_vector,
),
collections=[
# Test plan:
# 2. Operand source(s):
# 3. Operand shapes type(s):
# 4. Operand / output size of dimensions
TestCollection(
operators=TestCollectionData.all.operators,
input_sources=TestCollectionData.single.input_sources,
input_shapes=TestCollectionData.all.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
),
# # Test plan:
# # 5. Data format
# TestCollection(
# operators=TestCollectionData.all.operators,
# input_sources=TestCollectionData.single.input_sources,
# input_shapes=TestCollectionData.single.input_shapes,
# kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
# dev_data_formats=TestCollectionCommon.float.dev_data_formats,
# math_fidelities=TestCollectionData.single.math_fidelities,
# ),
# # Test plan:
# # 6. Math fidelity
# TestCollection(
# operators=TestCollectionData.all.operators,
# input_sources=TestCollectionData.single.input_sources,
# input_shapes=TestCollectionData.single.input_shapes,
# kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
# dev_data_formats=TestCollectionData.single.dev_data_formats,
# math_fidelities=TestCollectionData.all.math_fidelities,
# ),
],
failing_rules=[],
)


def get_test_plans() -> List[TestPlan]:
return [
TestParamsData.test_plan,
]