Add 'gelu' and 'leaky_relu' operators test_plan and failing_rules #1203

Open · wants to merge 2 commits into base: main
2 changes: 2 additions & 0 deletions forge/forge/op_repo/__init__.py
@@ -15,6 +15,7 @@
from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
from .datatypes import ShapeCalculationContext
from .pytorch_operators import pytorch_operator_repository

__ALL__ = [
"OperandNumInt",
@@ -26,4 +27,5 @@
"OperatorDefinition",
"OperatorRepository",
"ShapeCalculationContext",
"pytorch_operator_repository",
]
99 changes: 99 additions & 0 deletions forge/forge/op_repo/pytorch_operators.py
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# PyTorch repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# TODO: describe operands and shapes
_OPERATORS = [
OperatorDefinition(
"linear",
"torch.nn.Linear",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_features", int, 10, 50),
OperatorParamNumber("out_features", int, 10, 50),
],
),
OperatorDefinition(
"conv2d",
"torch.nn.Conv2d",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_channels", int, 10, 50),
OperatorParamNumber("out_channels", int, 10, 50),
OperatorParamNumber("kernel_size", int, 3, 3),
OperatorParamNumber("stride", int, 1, 1),
OperatorParamNumber("padding", int, 1, 1),
],
),
OperatorDefinition("relu", "torch.relu", 1),
OperatorDefinition("sqrt", "torch.sqrt", 1),
OperatorDefinition("reciprocal", "torch.reciprocal", 1),
OperatorDefinition("sigmoid", "torch.sigmoid", 1),
OperatorDefinition("abs", "torch.abs", 1),
OperatorDefinition("cos", "torch.cos", 1),
OperatorDefinition("exp", "torch.exp", 1),
OperatorDefinition("neg", "torch.neg", 1),
OperatorDefinition("rsqrt", "torch.rsqrt", 1),
OperatorDefinition("sin", "torch.sin", 1),
OperatorDefinition("square", "torch.square", 1),
OperatorDefinition("pow", "torch.pow", 1),
OperatorDefinition("clamp", "torch.clamp", 1),
OperatorDefinition("log", "torch.log", 1),
OperatorDefinition("log1p", "torch.log1p", 1),
OperatorDefinition("gelu", "torch.nn.functional.gelu", 1),
OperatorDefinition("leaky_relu", "torch.nn.functional.leaky_relu", 1),
OperatorDefinition("tanh", "torch.tanh", 1),
Comment on lines +37 to +54
@vbrkicTT (Contributor) commented on Feb 18, 2025:

Not-implemented unary operators should be defined here as well. Otherwise, tests for them fail with "IndexError: list index out of range" instead of xfailing as "Operator not implemented".

The PR should depend on #1239
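
A minimal sketch of that suggestion, assuming the OperatorDefinition API used above; the torch paths are illustrative, and the full list would mirror TestCollectionData.not_implemented in test_unary.py:

# Hedged sketch: registering not-yet-implemented unary operators so their
# tests resolve a definition and xfail as "Operator not implemented"
# instead of raising IndexError.
_NOT_IMPLEMENTED_OPERATORS = [
    OperatorDefinition("acos", "torch.acos", 1),
    OperatorDefinition("acosh", "torch.acosh", 1),
    # ...one entry per operator in TestCollectionData.not_implemented
]

pytorch_operator_repository = OperatorRepository(_OPERATORS + _NOT_IMPLEMENTED_OPERATORS)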

# OperatorDefinition("add", "torch.add", 1),
OperatorDefinition("add", "torch.add", 2),
OperatorDefinition("sub", "torch.sub", 2),
OperatorDefinition("mul", "torch.mul", 2),
OperatorDefinition("div", "torch.div", 2),
OperatorDefinition("ge", "torch.ge", 2),
# Non-linear activation functions
# HARDTANH = OperatorDefinition("hardtanh", 1)
# HARDSWISH = OperatorDefinition("hardswish", 1)
# RELU6 = OperatorDefinition("relu6", 1)
# ELU = OperatorDefinition("elu", 1)
# SELU = OperatorDefinition("selu", 1)
# CELU = OperatorDefinition("celu", 1)
# LEAKY_RELU = OperatorDefinition("leaky_relu", 1)
# PRELU = OperatorDefinition("prelu", 1)
# RRELU = OperatorDefinition("rrelu", 1)
# GLU = OperatorDefinition("glu", 1)
# GELU = OperatorDefinition("gelu", 1)
# LOGSIGMOID = OperatorDefinition("logsigmoid", 1)
# HARDSHRINK = OperatorDefinition("hardshrink", 1)
# TANHSHRINK = OperatorDefinition("tanhshrink", 1)
# SOFTSIGN = OperatorDefinition("softsign", 1)
# SOFTPLUS = OperatorDefinition("softplus", 1)
# SOFTMIN = OperatorDefinition("softmin", 1)
# SOFTMAX = OperatorDefinition("softmax", 1)
# SOFTSHRINK = OperatorDefinition("softshrink", 1)
# GUMBEL_SOFTMAX = OperatorDefinition("gumbel_softmax", 1)
# LOG_SOFTMAX = OperatorDefinition("log_softmax", 1)
# TANH = OperatorDefinition("tanh", 1)
# SIGMOID = OperatorDefinition("sigmoid", 1)
# HARDSIGMOID = OperatorDefinition("hardsigmoid", 1)
# SILU = OperatorDefinition("silu", 1)
# MISH = OperatorDefinition("mish", 1)
# BATCH_NORM = OperatorDefinition("batch_norm", 1)
# GROUP_NORM = OperatorDefinition("group_norm", 1)
# INSTANCE_NORM = OperatorDefinition("instance_norm", 1)
# LAYER_NORM = OperatorDefinition("layer_norm", 1)
# LOCAL_RESPONSE_NORM = OperatorDefinition("local_response_norm", 1)
# NORMALIZE = OperatorDefinition("normalize", 1)
OperatorDefinition("matmul", "torch.matmul", 2),
OperatorDefinition("eltwise", "torch.add", 2),
]


pytorch_operator_repository = OperatorRepository(list(_OPERATORS))
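
For illustration, a small sketch of how the repository is queried; get_by_name and full_name are the same lookups PytorchUtils uses in utils.py below:

from forge.op_repo import pytorch_operator_repository

# Fetch the registered definition for "gelu" and read its full torch path.
gelu_def = pytorch_operator_repository.get_by_name("gelu")
print(gelu_def.full_name)  # -> "torch.nn.functional.gelu"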
86 changes: 83 additions & 3 deletions forge/test/operators/pytorch/eltwise_unary/test_unary.py
@@ -49,14 +49,13 @@
# (/) Reuse inputs for selected operators


import torch

from typing import List, Dict
from loguru import logger
from forge import MathFidelity, DataFormat

from test.operators.utils import InputSourceFlags, VerifyUtils
from test.operators.utils import InputSource
from test.operators.utils import PytorchUtils
from test.operators.utils import TestVector
from test.operators.utils import TestPlan
from test.operators.utils import FailingReasons
@@ -90,7 +89,8 @@ def verify(
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch, test_vector.operator)
module = PytorchUtils.get_pytorch_module(test_vector.operator)
operator = getattr(module, test_vector.operator)

kwargs = test_vector.kwargs if test_vector.kwargs else {}

Expand Down Expand Up @@ -125,6 +125,7 @@ class TestParamsData:
__test__ = False

test_plan_implemented: TestPlan = None
test_plan_implemented_float: TestPlan = None
test_plan_not_implemented: TestPlan = None

no_kwargs = [
@@ -144,12 +145,27 @@ class TestParamsData:
{"exponent": 10.0},
]

kwargs_gelu = [
{"approximate": "tanh"},
{},
]

kwargs_leaky_relu = [
{"negative_slope": 0.01, "inplace": True},
{"negative_slope": 0.1, "inplace": False},
{},
]

@classmethod
def generate_kwargs(cls, test_vector: TestVector):
if test_vector.operator in ("clamp",):
return cls.kwargs_clamp
if test_vector.operator in ("pow",):
return cls.kwargs_pow
if test_vector.operator in ("gelu",):
return cls.kwargs_gelu
if test_vector.operator in ("leaky_relu",):
return cls.kwargs_leaky_relu
return cls.no_kwargs


@@ -179,6 +195,12 @@ class TestCollectionData:
"log1p",
],
)
implemented_float = TestCollection(
operators=[
"gelu",
"leaky_relu",
],
)
not_implemented = TestCollection(
operators=[
"acos",
@@ -691,6 +713,63 @@ class TestCollectionData:
)


TestParamsData.test_plan_implemented_float = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
test_vector,
),
collections=[
# Test gelu, leaky_relu operators collection:
TestCollection(
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
),
# Test gelu, leaky_relu data formats collection:
TestCollection(
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=[
item
for item in TestCollectionCommon.float.dev_data_formats
if item not in TestCollectionCommon.single.dev_data_formats
],
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test gelu, leaky_relu math fidelities collection:
TestCollection(
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
),
],
failing_rules=[
TestCollection(
operators=["gelu"],
input_shapes=[(1, 1)],
kwargs=[
{"approximate": "tanh"},
{},
],
failing_reason=FailingReasons.DATA_MISMATCH,
),
TestCollection(
operators=["leaky_relu"],
input_sources=[InputSource.CONST_EVAL_PASS],
input_shapes=[(1, 1)],
kwargs=[{"negative_slope": 0.01, "inplace": True}],
failing_reason=FailingReasons.DATA_MISMATCH,
),
],
)


TestParamsData.test_plan_not_implemented = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
@@ -718,5 +797,6 @@ class TestCollectionData:
def get_test_plans() -> List[TestPlan]:
return [
TestParamsData.test_plan_implemented,
TestParamsData.test_plan_implemented_float,
TestParamsData.test_plan_not_implemented,
]
2 changes: 2 additions & 0 deletions forge/test/operators/utils/__init__.py
@@ -14,6 +14,7 @@
from .utils import LoggerUtils
from .utils import RateLimiter
from .utils import FrameworkModelType
from .utils import PytorchUtils
from .features import TestFeaturesConfiguration
from .plan import InputSource
from .plan import TestVector
@@ -47,6 +48,7 @@
"RateLimiter",
"TestFeaturesConfiguration",
"FrameworkModelType",
"PytorchUtils",
"InputSource",
"TestVector",
"TestCollection",
17 changes: 17 additions & 0 deletions forge/test/operators/utils/utils.py
@@ -19,6 +19,7 @@

from forge import ForgeModule, Module, DepricatedVerifyConfig
from forge.op_repo import TensorShape
from forge.op_repo.pytorch_operators import pytorch_operator_repository
from forge.verify import TestKind # , verify_module
from forge._C import MathFidelity

@@ -315,3 +316,19 @@ def limit_info(self) -> str:
return f"{self.current_value} <= {self.current_limit}"
else:
return f"{self.current_value} > {self.current_limit}"


class PytorchUtils:
"""Utility functions for PyTorch operators"""

@staticmethod
def get_pytorch_module(operator_name: str):
"""Retrieve the module that contains a given operator, based on the operator's full name.\n
For example, for "torch.nn.functional.gelu", the function returns the module torch.nn.functional."""
repo_operator = pytorch_operator_repository.get_by_name(operator_name).full_name
module_name = repo_operator.rsplit(".", 1)[0]
# module = importlib.import_module(module_name) # bad performance
module = torch
if module_name == "torch.nn.functional":
module = torch.nn.functional
return module
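
A usage sketch (assuming the operator definitions above, where "gelu" is registered as torch.nn.functional.gelu and "relu" as torch.relu):

import torch

from test.operators.utils import PytorchUtils

# "gelu" lives in torch.nn.functional, while "relu" resolves to the top-level torch module.
gelu_module = PytorchUtils.get_pytorch_module("gelu")
assert gelu_module is torch.nn.functional

relu_module = PytorchUtils.get_pytorch_module("relu")
assert relu_module is torch
operator = getattr(relu_module, "relu")  # the same lookup TestVerification.verify performs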