Skip to content

Commit

Permalink
Extract gely and leaky_rely test collections
Browse files Browse the repository at this point in the history
  • Loading branch information
vobojevicTT committed Feb 12, 2025
1 parent 3a81297 commit 6d20381
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 18 deletions.
78 changes: 61 additions & 17 deletions forge/test/operators/pytorch/eltwise_unary/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class TestParamsData:
__test__ = False

test_plan_implemented: TestPlan = None
test_plan_implemented_float: TestPlan = None
test_plan_not_implemented: TestPlan = None

no_kwargs = [
Expand All @@ -146,11 +147,13 @@ class TestParamsData:

kwargs_gelu = [
{"approximate": "tanh"},
{},
]

kwargs_leaky_relu = [
{"negative_slope": 0.01, "inplace": True},
{"negative_slope": 0.1, "inplace": False},
{},
]

@classmethod
Expand Down Expand Up @@ -190,6 +193,10 @@ class TestCollectionData:
# "clip", # alias for clamp
"log",
"log1p",
],
)
implemented_float = TestCollection(
operators=[
"gelu",
"leaky_relu",
],
Expand Down Expand Up @@ -297,14 +304,6 @@ class TestCollectionData:
dev_data_formats=TestCollectionCommon.all.dev_data_formats,
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test gelu and leaky_relu operators with default kwargs (no kwargs)
# gelu: approximate='none'
# leaky_relu: negative_slope=0.01, inplace=False
TestCollection(
operators=["gelu", "leaky_relu"],
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
),
],
failing_rules=[
# Skip 2D shapes as we don't test them:
Expand Down Expand Up @@ -710,18 +709,62 @@ class TestCollectionData:
],
failing_reason=FailingReasons.DATA_MISMATCH,
),
# gelu, leaky_relu: not implemented for Int types as shouldn't be implemented:
],
)


TestParamsData.test_plan_implemented_float = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
test_vector,
),
collections=[
# Test gelu, leaky_relu operators collection:
TestCollection(
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
),
# Test gelu, leaky_relu data formats collection:
TestCollection(
operators=["gelu", "leaky_relu"],
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=[
DataFormat.RawUInt8,
DataFormat.RawUInt16,
DataFormat.RawUInt32,
DataFormat.Int8,
DataFormat.UInt16,
DataFormat.Int32,
item
for item in TestCollectionCommon.float.dev_data_formats
if item not in TestCollectionCommon.single.dev_data_formats
],
skip_reason=FailingReasons.NOT_IMPLEMENTED,
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test gelu, leaky_relu math fidelities collection:
TestCollection(
operators=TestCollectionData.implemented_float.operators,
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
),
],
failing_rules=[
TestCollection(
operators=["gelu"],
input_shapes=[(1, 1)],
kwargs=[
{"approximate": "tanh"},
{},
],
failing_reason=FailingReasons.DATA_MISMATCH,
),
TestCollection(
operators=["leaky_relu"],
input_sources=[InputSource.CONST_EVAL_PASS],
input_shapes=[(1, 1)],
kwargs=[{"negative_slope": 0.01, "inplace": True}],
failing_reason=FailingReasons.DATA_MISMATCH,
),
],
)
Expand Down Expand Up @@ -754,5 +797,6 @@ class TestCollectionData:
def get_test_plans() -> List[TestPlan]:
return [
TestParamsData.test_plan_implemented,
TestParamsData.test_plan_implemented_float,
TestParamsData.test_plan_not_implemented,
]
1 change: 0 additions & 1 deletion forge/test/operators/utils/failing_reasons.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def validate_exception_message(
in f"{ex}",
lambda ex: isinstance(ex, RuntimeError)
and "info:\nBinaryOpType cannot be mapped to BcastOpMath" in f"{ex}",
lambda ex: isinstance(ex, RuntimeError) and "not implemented for 'Int'" in f"{ex}",
],
FailingReasons.ALLOCATION_FAILED: [
lambda ex: isinstance(ex, RuntimeError)
Expand Down

0 comments on commit 6d20381

Please sign in to comment.