Add gelu and leaky_relu operators tests [skip ci]
 - Add pytorch operators repository
 - Add PytorchUtils class
 - Add gelu and leaky_relu operators tests
vobojevicTT committed Feb 11, 2025
1 parent 2102aff commit c124c67
Showing 6 changed files with 160 additions and 3 deletions.
2 changes: 2 additions & 0 deletions forge/forge/op_repo/__init__.py
@@ -15,6 +15,7 @@
from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
from .datatypes import ShapeCalculationContext
from .pytorch_operators import pytorch_operator_repository

__ALL__ = [
"OperandNumInt",
@@ -26,4 +27,5 @@
"OperatorDefinition",
"OperatorRepository",
"ShapeCalculationContext",
"pytorch_operator_repository",
]
99 changes: 99 additions & 0 deletions forge/forge/op_repo/pytorch_operators.py
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# PyTorch repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# TODO: describe operands and shapes
_OPERATORS = [
OperatorDefinition(
"linear",
"torch.nn.Linear",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_features", int, 10, 50),
OperatorParamNumber("out_features", int, 10, 50),
],
),
OperatorDefinition(
"conv2d",
"torch.nn.Conv2d",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_channels", int, 10, 50),
OperatorParamNumber("out_channels", int, 10, 50),
OperatorParamNumber("kernel_size", int, 3, 3),
OperatorParamNumber("stride", int, 1, 1),
OperatorParamNumber("padding", int, 1, 1),
],
),
OperatorDefinition("relu", "torch.relu", 1),
OperatorDefinition("sqrt", "torch.sqrt", 1),
OperatorDefinition("reciprocal", "torch.reciprocal", 1),
OperatorDefinition("sigmoid", "torch.sigmoid", 1),
OperatorDefinition("abs", "torch.abs", 1),
OperatorDefinition("cos", "torch.cos", 1),
OperatorDefinition("exp", "torch.exp", 1),
OperatorDefinition("neg", "torch.neg", 1),
OperatorDefinition("rsqrt", "torch.rsqrt", 1),
OperatorDefinition("sin", "torch.sin", 1),
OperatorDefinition("square", "torch.square", 1),
OperatorDefinition("pow", "torch.pow", 1),
OperatorDefinition("clamp", "torch.clamp", 1),
OperatorDefinition("log", "torch.log", 1),
OperatorDefinition("log1p", "torch.log1p", 1),
OperatorDefinition("gelu", "torch.nn.functional.gelu", 1),
OperatorDefinition("leaky_relu", "torch.nn.functional.leaky_relu", 1),
OperatorDefinition("tanh", "torch.tanh", 1),
# OperatorDefinition("add", "torch.add", 1),
OperatorDefinition("add", "torch.add", 2),
OperatorDefinition("sub", "torch.sub", 2),
OperatorDefinition("mul", "torch.mul", 2),
OperatorDefinition("div", "torch.div", 2),
OperatorDefinition("ge", "torch.ge", 2),
# Non-linear activation functions
# HARDTANH = OperatorDefinition("hardtanh", 1)
# HARDSWISH = OperatorDefinition("hardswish", 1)
# RELU6 = OperatorDefinition("relu6", 1)
# ELU = OperatorDefinition("elu", 1)
# SELU = OperatorDefinition("selu", 1)
# CELU = OperatorDefinition("celu", 1)
# LEAKY_RELU = OperatorDefinition("leaky_relu", 1)
# PRELU = OperatorDefinition("prelu", 1)
# RRELU = OperatorDefinition("rrelu", 1)
# GLU = OperatorDefinition("glu", 1)
# GELU = OperatorDefinition("gelu", 1)
# LOGSIGMOID = OperatorDefinition("logsigmoid", 1)
# HARDSHRINK = OperatorDefinition("hardshrink", 1)
# TANHSHRINK = OperatorDefinition("tanhshrink", 1)
# SOFTSIGN = OperatorDefinition("softsign", 1)
# SOFTPLUS = OperatorDefinition("softplus", 1)
# SOFTMIN = OperatorDefinition("softmin", 1)
# SOFTMAX = OperatorDefinition("softmax", 1)
# SOFTSHRINK = OperatorDefinition("softshrink", 1)
# GUMBEL_SOFTMAX = OperatorDefinition("gumbel_softmax", 1)
# LOG_SOFTMAX = OperatorDefinition("log_softmax", 1)
# TANH = OperatorDefinition("tanh", 1)
# SIGMOID = OperatorDefinition("sigmoid", 1)
# HARDSIGMOID = OperatorDefinition("hardsigmoid", 1)
# SILU = OperatorDefinition("silu", 1)
# MISH = OperatorDefinition("mish", 1)
# BATCH_NORM = OperatorDefinition("batch_norm", 1)
# GROUP_NORM = OperatorDefinition("group_norm", 1)
# INSTANCE_NORM = OperatorDefinition("instance_norm", 1)
# LAYER_NORM = OperatorDefinition("layer_norm", 1)
# LOCAL_RESPONSE_NORM = OperatorDefinition("local_response_norm", 1)
# NORMALIZE = OperatorDefinition("normalize", 1)
OperatorDefinition("matmul", "torch.matmul", 2),
OperatorDefinition("eltwise", "torch.add", 2),
]


pytorch_operator_repository = OperatorRepository(list(_OPERATORS))
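For reference, a minimal sketch of how this repository can be queried; it assumes only the `get_by_name()` accessor and the `full_name` attribute that the `PytorchUtils` helper further below relies on:

```python
# Hedged usage sketch: look up an operator definition registered above.
# Assumes OperatorRepository.get_by_name() and OperatorDefinition.full_name,
# the same accessors used by PytorchUtils in forge/test/operators/utils/utils.py.
from forge.op_repo import pytorch_operator_repository

gelu_def = pytorch_operator_repository.get_by_name("gelu")
print(gelu_def.full_name)  # "torch.nn.functional.gelu"
```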
42 changes: 39 additions & 3 deletions forge/test/operators/pytorch/eltwise_unary/test_unary.py
Original file line number Diff line number Diff line change
@@ -49,14 +49,13 @@
# (/) Reuse inputs for selected operators


import torch

from typing import List, Dict
from loguru import logger
from forge import MathFidelity, DataFormat

from test.operators.utils import InputSourceFlags, VerifyUtils
from test.operators.utils import InputSource
from test.operators.utils import PytorchUtils
from test.operators.utils import TestVector
from test.operators.utils import TestPlan
from test.operators.utils import FailingReasons
@@ -90,7 +89,8 @@ def verify(
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch, test_vector.operator)
module = PytorchUtils.get_pytorch_module(test_vector.operator)
operator = getattr(module, test_vector.operator)

kwargs = test_vector.kwargs if test_vector.kwargs else {}

@@ -144,12 +144,25 @@ class TestParamsData:
{"exponent": 10.0},
]

kwargs_gelu = [
{"approximate": "tanh"},
]

kwargs_leaky_relu = [
{"negative_slope": 0.01, "inplace": True},
{"negative_slope": 0.1, "inplace": False},
]

@classmethod
def generate_kwargs(cls, test_vector: TestVector):
if test_vector.operator in ("clamp",):
return cls.kwargs_clamp
if test_vector.operator in ("pow",):
return cls.kwargs_pow
if test_vector.operator in ("gelu",):
return cls.kwargs_gelu
if test_vector.operator in ("leaky_relu",):
return cls.kwargs_leaky_relu
return cls.no_kwargs


@@ -177,6 +190,8 @@ class TestCollectionData:
# "clip", # alias for clamp
"log",
"log1p",
"gelu",
"leaky_relu",
],
)
not_implemented = TestCollection(
@@ -282,6 +297,14 @@ class TestCollectionData:
dev_data_formats=TestCollectionCommon.all.dev_data_formats,
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test gelu and leaky_relu operators with default kwargs (no kwargs)
# gelu: approximate='none'
# leaky_relu: negative_slope=0.01, inplace=False
TestCollection(
operators=["gelu", "leaky_relu"],
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
),
],
failing_rules=[
# Skip 2D shapes as we don't test them:
@@ -687,6 +710,19 @@ class TestCollectionData:
],
failing_reason=FailingReasons.DATA_MISMATCH,
),
# gelu, leaky_relu: not implemented for integer data formats (and not expected to be):
TestCollection(
operators=["gelu", "leaky_relu"],
dev_data_formats=[
DataFormat.RawUInt8,
DataFormat.RawUInt16,
DataFormat.RawUInt32,
DataFormat.Int8,
DataFormat.UInt16,
DataFormat.Int32,
],
skip_reason=FailingReasons.NOT_IMPLEMENTED,
),
],
)

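To make the new kwargs concrete, here is a hedged sketch of the plain PyTorch calls they correspond to; the tests themselves invoke the operator indirectly through `VerifyUtils`, so this is illustrative only:

```python
# Illustrative only: the torch.nn.functional calls that kwargs_gelu,
# kwargs_leaky_relu, and the "no kwargs" test collection map to.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)

F.gelu(x)                                            # default: approximate="none"
F.gelu(x, approximate="tanh")                        # kwargs_gelu
F.leaky_relu(x, negative_slope=0.01, inplace=True)   # kwargs_leaky_relu[0]
F.leaky_relu(x, negative_slope=0.1, inplace=False)   # kwargs_leaky_relu[1]
```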
2 changes: 2 additions & 0 deletions forge/test/operators/utils/__init__.py
@@ -14,6 +14,7 @@
from .utils import LoggerUtils
from .utils import RateLimiter
from .utils import FrameworkModelType
from .utils import PytorchUtils
from .features import TestFeaturesConfiguration
from .plan import InputSource
from .plan import TestVector
@@ -47,6 +48,7 @@
"RateLimiter",
"TestFeaturesConfiguration",
"FrameworkModelType",
"PytorchUtils",
"InputSource",
"TestVector",
"TestCollection",
1 change: 1 addition & 0 deletions forge/test/operators/utils/failing_reasons.py
@@ -123,6 +123,7 @@ def validate_exception_message(
in f"{ex}",
lambda ex: isinstance(ex, RuntimeError)
and "info:\nBinaryOpType cannot be mapped to BcastOpMath" in f"{ex}",
lambda ex: isinstance(ex, RuntimeError) and "not implemented for 'Int'" in f"{ex}",
],
FailingReasons.ALLOCATION_FAILED: [
lambda ex: isinstance(ex, RuntimeError)
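As a hedged illustration of what the new matcher accepts (the exact message text depends on the op and backend; this particular string is an assumption), consider:

```python
# Assumed example of a PyTorch error for an unsupported integer dtype; the
# matcher only requires a RuntimeError whose text contains "not implemented for 'Int'".
ex = RuntimeError('"leaky_relu_cpu" not implemented for \'Int\'')
assert isinstance(ex, RuntimeError) and "not implemented for 'Int'" in f"{ex}"
```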
17 changes: 17 additions & 0 deletions forge/test/operators/utils/utils.py
@@ -19,6 +19,7 @@

from forge import ForgeModule, Module, DepricatedVerifyConfig
from forge.op_repo import TensorShape
from forge.op_repo.pytorch_operators import pytorch_operator_repository
from forge.verify import TestKind # , verify_module
from forge._C import MathFidelity

@@ -315,3 +316,19 @@ def limit_info(self) -> str:
return f"{self.current_value} <= {self.current_limit}"
else:
return f"{self.current_value} > {self.current_limit}"


class PytorchUtils:
"""Utility functions for PyTorch operators"""

@staticmethod
def get_pytorch_module(module_name: str):
"""Retrieving the module that contains a given operator, based on its full name.\n
For example, for "torch.nn.functional.gelu", the function returns module torch.nn.functional."""
repo_operator = pytorch_operator_repository.get_by_name(module_name).full_name
module_name = repo_operator.rsplit(".", 1)[0]
# module = importlib.import_module(module_name) # bad performance
module = torch
if module_name == "torch.nn.functional":
module = torch.nn.functional
return module
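A minimal usage sketch (not part of the commit) mirroring how the verify() helper in test_unary.py resolves the callable from an operator name; the real tests pass the result through VerifyUtils rather than calling it directly:

```python
# Hedged usage sketch: resolve the containing module, then the callable.
# Assumes the operator name is registered in pytorch_operator_repository.
import torch
from test.operators.utils import PytorchUtils

module = PytorchUtils.get_pytorch_module("leaky_relu")  # -> torch.nn.functional
operator = getattr(module, "leaky_relu")
y = operator(torch.randn(2, 4), negative_slope=0.1)
```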
