diff --git a/forge/test/operators/pytorch/eltwise_unary/test_unary.py b/forge/test/operators/pytorch/eltwise_unary/test_unary.py index bb05610a7..3a69b7dc4 100644 --- a/forge/test/operators/pytorch/eltwise_unary/test_unary.py +++ b/forge/test/operators/pytorch/eltwise_unary/test_unary.py @@ -125,6 +125,7 @@ class TestParamsData: __test__ = False test_plan_implemented: TestPlan = None + test_plan_implemented_float: TestPlan = None test_plan_not_implemented: TestPlan = None no_kwargs = [ @@ -146,11 +147,13 @@ class TestParamsData: kwargs_gelu = [ {"approximate": "tanh"}, + {}, ] kwargs_leaky_relu = [ {"negative_slope": 0.01, "inplace": True}, {"negative_slope": 0.1, "inplace": False}, + {}, ] @classmethod @@ -190,6 +193,10 @@ class TestCollectionData: # "clip", # alias for clamp "log", "log1p", + ], + ) + implemented_float = TestCollection( + operators=[ "gelu", "leaky_relu", ], @@ -297,14 +304,6 @@ class TestCollectionData: dev_data_formats=TestCollectionCommon.all.dev_data_formats, math_fidelities=TestCollectionCommon.single.math_fidelities, ), - # Test gelu and leaky_relu operators with default kwargs (no kwargs) - # gelu: approximate='none' - # leaky_relu: negative_slope=0.01, inplace=False - TestCollection( - operators=["gelu", "leaky_relu"], - input_sources=TestCollectionCommon.all.input_sources, - input_shapes=TestCollectionCommon.all.input_shapes, - ), ], failing_rules=[ # Skip 2D shapes as we don't test them: @@ -710,18 +709,62 @@ class TestCollectionData: ], failing_reason=FailingReasons.DATA_MISMATCH, ), - # gelu, leaky_relu: not implemented for Int types as shouldn't be implemented: + ], +) + + +TestParamsData.test_plan_implemented_float = TestPlan( + verify=lambda test_device, test_vector: TestVerification.verify( + test_device, + test_vector, + ), + collections=[ + # Test gelu, leaky_relu operators collection: + TestCollection( + operators=TestCollectionData.implemented_float.operators, + input_sources=TestCollectionCommon.all.input_sources, + input_shapes=TestCollectionCommon.all.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + ), + # Test gelu, leaky_relu data formats collection: TestCollection( - operators=["gelu", "leaky_relu"], + operators=TestCollectionData.implemented_float.operators, + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), dev_data_formats=[ - DataFormat.RawUInt8, - DataFormat.RawUInt16, - DataFormat.RawUInt32, - DataFormat.Int8, - DataFormat.UInt16, - DataFormat.Int32, + item + for item in TestCollectionCommon.float.dev_data_formats + if item not in TestCollectionCommon.single.dev_data_formats ], - skip_reason=FailingReasons.NOT_IMPLEMENTED, + math_fidelities=TestCollectionCommon.single.math_fidelities, + ), + # Test gelu, leaky_relu math fidelities collection: + TestCollection( + operators=TestCollectionData.implemented_float.operators, + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=TestCollectionCommon.single.dev_data_formats, + math_fidelities=TestCollectionCommon.all.math_fidelities, + ), + ], + failing_rules=[ + TestCollection( + operators=["gelu"], + input_shapes=[(1, 1)], + kwargs=[ + {"approximate": "tanh"}, + {}, + ], + failing_reason=FailingReasons.DATA_MISMATCH, + ), + TestCollection( + operators=["leaky_relu"], + input_sources=[InputSource.CONST_EVAL_PASS], + input_shapes=[(1, 1)], + kwargs=[{"negative_slope": 0.01, "inplace": True}], + failing_reason=FailingReasons.DATA_MISMATCH, ), ], ) @@ -754,5 +797,6 @@ class TestCollectionData: def get_test_plans() -> List[TestPlan]: return [ TestParamsData.test_plan_implemented, + TestParamsData.test_plan_implemented_float, TestParamsData.test_plan_not_implemented, ] diff --git a/forge/test/operators/utils/failing_reasons.py b/forge/test/operators/utils/failing_reasons.py index b18ecf4aa..4c456233e 100644 --- a/forge/test/operators/utils/failing_reasons.py +++ b/forge/test/operators/utils/failing_reasons.py @@ -123,7 +123,6 @@ def validate_exception_message( in f"{ex}", lambda ex: isinstance(ex, RuntimeError) and "info:\nBinaryOpType cannot be mapped to BcastOpMath" in f"{ex}", - lambda ex: isinstance(ex, RuntimeError) and "not implemented for 'Int'" in f"{ex}", ], FailingReasons.ALLOCATION_FAILED: [ lambda ex: isinstance(ex, RuntimeError)