diff --git a/forge/test/operators/pytorch/conftest.py b/forge/test/operators/pytorch/conftest.py index dd012db3f..c7e1501c8 100644 --- a/forge/test/operators/pytorch/conftest.py +++ b/forge/test/operators/pytorch/conftest.py @@ -58,14 +58,21 @@ def pytest_runtest_makereport(item: _pytest.python.Function, call: _pytest.runne if report.when == "call" or (report.when == "setup" and report.skipped): try: - log_test_vector_properties(item, report, xfail_reason) + log_test_vector_properties( + item=item, + report=report, + xfail_reason=xfail_reason, + exception=call.excinfo.value if call.excinfo is not None else None, + ) except Exception as e: logger.error(f"Failed to log test vector properties: {e}") logger.exception(e) pass -def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str): +def log_test_vector_properties( + item: _pytest.python.Function, report: _pytest.reports.TestReport, xfail_reason: str, exception: Exception +): original_name = item.originalname test_id = item.name test_id = test_id.replace(f"{original_name}[", "") @@ -91,3 +98,21 @@ def log_test_vector_properties(item: _pytest.python.Function, report: _pytest.re if xfail_reason is not None: item.user_properties.append(("xfail_reason", xfail_reason)) item.user_properties.append(("outcome", report.outcome)) + + if exception is not None: + error_message = f"{exception}" + + if "Observed maximum relative diff" in error_message: + error_message_lines = error_message.split("\n") + observed_error_lines = [line for line in error_message_lines if "Observed maximum relative diff" in line] + if observed_error_lines: + observed_error_line = observed_error_lines[0] + # Example: "- Observed maximum relative diff: 0.0008770461427047849, maximum absolute diff: 0.0009063482284545898" + rtol = float(observed_error_line.split(",")[0].split(":")[1].strip()) + atol = float(observed_error_line.split(",")[1].split(":")[1].strip()) + else: + logger.error(f"Error parsing 'Observed maximum relative diff' from the exception: {error_message}") + rtol = None + atol = None + item.user_properties.append(("all_close_rtol", rtol)) + item.user_properties.append(("all_close_atol", atol))