Commit
Add pytorch operators repository
vobojevicTT committed Feb 10, 2025
1 parent f0b1bb7 commit 8f0aa44
Showing 4 changed files with 145 additions and 2 deletions.
2 changes: 2 additions & 0 deletions forge/forge/op_repo/__init__.py
@@ -15,6 +15,7 @@
from .datatypes import OperandNumInt, OperandNumTuple, OperandNumRange
from .datatypes import TensorShape, OperatorParam, OperatorParamNumber, OperatorDefinition, OperatorRepository
from .datatypes import ShapeCalculationContext
from .pytorch_operators import pytorch_operator_repository

__ALL__ = [
"OperandNumInt",
@@ -26,4 +27,5 @@
"OperatorDefinition",
"OperatorRepository",
"ShapeCalculationContext",
"pytorch_operator_repository",
]
99 changes: 99 additions & 0 deletions forge/forge/op_repo/pytorch_operators.py
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

# PyTorch repository operators


from .datatypes import OperatorDefinition, OperatorRepository
from .datatypes import OperatorParamNumber


# TODO: describe operands and shapes
_OPERATORS = [
OperatorDefinition(
"linear",
"torch.nn.Linear",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_features", int, 10, 50),
OperatorParamNumber("out_features", int, 10, 50),
],
),
OperatorDefinition(
"conv2d",
"torch.nn.Conv2d",
1,
instantiate=True,
constructor_params=[
OperatorParamNumber("in_channels", int, 10, 50),
OperatorParamNumber("out_channels", int, 10, 50),
OperatorParamNumber("kernel_size", int, 3, 3),
OperatorParamNumber("stride", int, 1, 1),
OperatorParamNumber("padding", int, 1, 1),
],
),
OperatorDefinition("relu", "torch.relu", 1),
OperatorDefinition("sqrt", "torch.sqrt", 1),
OperatorDefinition("reciprocal", "torch.reciprocal", 1),
OperatorDefinition("sigmoid", "torch.sigmoid", 1),
OperatorDefinition("abs", "torch.abs", 1),
OperatorDefinition("cos", "torch.cos", 1),
OperatorDefinition("exp", "torch.exp", 1),
OperatorDefinition("neg", "torch.neg", 1),
OperatorDefinition("rsqrt", "torch.rsqrt", 1),
OperatorDefinition("sin", "torch.sin", 1),
OperatorDefinition("square", "torch.square", 1),
OperatorDefinition("pow", "torch.pow", 1),
OperatorDefinition("clamp", "torch.clamp", 1),
OperatorDefinition("log", "torch.log", 1),
OperatorDefinition("log1p", "torch.log1p", 1),
OperatorDefinition("gelu", "torch.nn.functional.gelu", 1),
OperatorDefinition("leaky_relu", "torch.nn.functional.leaky_relu", 1),
OperatorDefinition("tanh", "torch.tanh", 1),
# OperatorDefinition("add", "torch.add", 1),
OperatorDefinition("add", "torch.add", 2),
OperatorDefinition("sub", "torch.sub", 2),
OperatorDefinition("mul", "torch.mul", 2),
OperatorDefinition("div", "torch.div", 2),
OperatorDefinition("ge", "torch.ge", 2),
# Non-linear activation functions
# HARDTANH = OperatorDefinition("hardtanh", 1)
# HARDSWISH = OperatorDefinition("hardswish", 1)
# RELU6 = OperatorDefinition("relu6", 1)
# ELU = OperatorDefinition("elu", 1)
# SELU = OperatorDefinition("selu", 1)
# CELU = OperatorDefinition("celu", 1)
# LEAKY_RELU = OperatorDefinition("leaky_relu", 1)
# PRELU = OperatorDefinition("prelu", 1)
# RRELU = OperatorDefinition("rrelu", 1)
# GLU = OperatorDefinition("glu", 1)
# GELU = OperatorDefinition("gelu", 1)
# LOGSIGMOID = OperatorDefinition("logsigmoid", 1)
# HARDSHRINK = OperatorDefinition("hardshrink", 1)
# TANHSHRINK = OperatorDefinition("tanhshrink", 1)
# SOFTSIGN = OperatorDefinition("softsign", 1)
# SOFTPLUS = OperatorDefinition("softplus", 1)
# SOFTMIN = OperatorDefinition("softmin", 1)
# SOFTMAX = OperatorDefinition("softmax", 1)
# SOFTSHRINK = OperatorDefinition("softshrink", 1)
# GUMBEL_SOFTMAX = OperatorDefinition("gumbel_softmax", 1)
# LOG_SOFTMAX = OperatorDefinition("log_softmax", 1)
# TANH = OperatorDefinition("tanh", 1)
# SIGMOID = OperatorDefinition("sigmoid", 1)
# HARDSIGMOID = OperatorDefinition("hardsigmoid", 1)
# SILU = OperatorDefinition("silu", 1)
# MISH = OperatorDefinition("mish", 1)
# BATCH_NORM = OperatorDefinition("batch_norm", 1)
# GROUP_NORM = OperatorDefinition("group_norm", 1)
# INSTANCE_NORM = OperatorDefinition("instance_norm", 1)
# LAYER_NORM = OperatorDefinition("layer_norm", 1)
# LOCAL_RESPONSE_NORM = OperatorDefinition("local_response_norm", 1)
# NORMALIZE = OperatorDefinition("normalize", 1)
OperatorDefinition("matmul", "torch.matmul", 2),
OperatorDefinition("eltwise", "torch.add", 2),
]


pytorch_operator_repository = OperatorRepository(list(_OPERATORS))
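
For orientation, a minimal usage sketch of the new repository (illustrative only; it assumes OperatorRepository.get_by_name() returns an OperatorDefinition exposing full_name, as relied on in test_unary.py below):

from forge.op_repo import pytorch_operator_repository

# Look up an operator definition by its repository name.
op_def = pytorch_operator_repository.get_by_name("gelu")
print(op_def.full_name)  # "torch.nn.functional.gelu"

# Definitions created with instantiate=True (e.g. "linear", "conv2d") carry
# constructor params such as in_features/out_features instead of being plain calls.
linear_def = pytorch_operator_repository.get_by_name("linear")
print(linear_def.full_name)  # "torch.nn.Linear"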
45 changes: 43 additions & 2 deletions forge/test/operators/pytorch/eltwise_unary/test_unary.py
@@ -49,7 +49,7 @@
# (/) Reuse inputs for selected operators


import torch
import importlib

from typing import List, Dict
from loguru import logger
@@ -64,6 +64,7 @@
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon
from test.operators.utils import ValueRanges
from forge.op_repo.pytorch_operators import pytorch_operator_repository

from .models import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass

@@ -90,7 +91,11 @@ def verify(
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch, test_vector.operator)
# Retrieve the operator from the module it belongs to:
repo_operator = pytorch_operator_repository.get_by_name(test_vector.operator).full_name
module_name = repo_operator.rsplit(".", 1)[0]
module = importlib.import_module(module_name)
operator = getattr(module, test_vector.operator)

kwargs = test_vector.kwargs if test_vector.kwargs else {}

@@ -144,12 +149,25 @@ class TestParamsData:
{"exponent": 10.0},
]

kwargs_gelu = [
{"approximate": "tanh"},
]

kwargs_leaky_relu = [
{"negative_slope": 0.01, "inplace": True},
{"negative_slope": 0.1, "inplace": False},
]

@classmethod
def generate_kwargs(cls, test_vector: TestVector):
if test_vector.operator in ("clamp",):
return cls.kwargs_clamp
if test_vector.operator in ("pow",):
return cls.kwargs_pow
if test_vector.operator in ("gelu",):
return cls.kwargs_gelu
if test_vector.operator in ("leaky_relu",):
return cls.kwargs_leaky_relu
return cls.no_kwargs
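
For reference, an illustrative sketch of the calls these kwargs translate to once the operator has been resolved (standard torch signatures; the input tensor is an assumption):

import torch

x = torch.randn(4, 4)
torch.nn.functional.gelu(x, approximate="tanh")                       # kwargs_gelu
torch.nn.functional.leaky_relu(x, negative_slope=0.01, inplace=True)  # kwargs_leaky_relu[0]
torch.nn.functional.leaky_relu(x, negative_slope=0.1, inplace=False)  # kwargs_leaky_relu[1]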


@@ -177,6 +195,8 @@ class TestCollectionData:
# "clip", # alias for clamp
"log",
"log1p",
"gelu",
"leaky_relu",
],
)
not_implemented = TestCollection(
@@ -282,6 +302,14 @@ class TestCollectionData:
dev_data_formats=TestCollectionCommon.all.dev_data_formats,
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test gelu and leaky_relu operators with default kwargs (no kwargs)
# gelu: approximate='none'
# leaky_relu: negative_slope=0.01, inplace=False
TestCollection(
operators=["gelu", "leaky_relu"],
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
),
],
failing_rules=[
# Skip 2D shapes as we don't test them:
@@ -687,6 +715,19 @@ class TestCollectionData:
],
failing_reason=FailingReasons.DATA_MISMATCH,
),
# gelu, leaky_relu: not implemented for Int types, and shouldn't be:
TestCollection(
operators=["gelu", "leaky_relu"],
dev_data_formats=[
DataFormat.RawUInt8,
DataFormat.RawUInt16,
DataFormat.RawUInt32,
DataFormat.Int8,
DataFormat.UInt16,
DataFormat.Int32,
],
skip_reason=FailingReasons.NOT_IMPLEMENTED,
),
],
)

1 change: 1 addition & 0 deletions forge/test/operators/utils/failing_reasons.py
@@ -123,6 +123,7 @@ def validate_exception_message(
in f"{ex}",
lambda ex: isinstance(ex, RuntimeError)
and "info:\nBinaryOpType cannot be mapped to BcastOpMath" in f"{ex}",
lambda ex: isinstance(ex, RuntimeError) and "not implemented for 'Int'" in f"{ex}",
],
FailingReasons.ALLOCATION_FAILED: [
lambda ex: isinstance(ex, RuntimeError)
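
For context, an illustrative sketch of the error the new matcher accepts: integer inputs to these activations raise a RuntimeError whose message contains the matched substring (the exact prefix, e.g. the kernel name, varies across torch builds):

import torch

x = torch.ones(2, 3, dtype=torch.int32)
try:
    torch.nn.functional.leaky_relu(x)
except RuntimeError as ex:
    print("not implemented for 'Int'" in f"{ex}")  # expected: True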
