Skip to content

Commit

Permalink
Skip forge verification
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT committed Jan 30, 2025
1 parent 8573421 commit c2469c1
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/src/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ Full list of supported query parameters
| TEST_ID | Id of a test containing test parameters | test_single |
| ID_FILE | Path to a file containing test ids | test_ids |

Test configuration parameters

| Parameter | Description | Supported by commands |
| ------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------- |
| SKIP_FORGE_VERIFICATION | Skip Forge model verification including model compiling and inference | all |

To check supported values and options for each query parameter please run command `print_query_docs`.

Expand Down
26 changes: 26 additions & 0 deletions forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def print_query_params(cls, max_width=80):
cls.print_query_values(max_width)
print("Query examples:")
cls.print_query_examples(max_width)
print("Configuration parameters:")
cls.print_configuration_params(max_width)
print("Configuration examples:")
cls.print_configuration_examples(max_width)

@classmethod
def print_query_values(cls, max_width=80):
Expand Down Expand Up @@ -500,6 +504,28 @@ def print_query_examples(cls, max_width=80):

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"])

@classmethod
def print_configuration_params(cls, max_width=80):

parameters = [
{
"name": "SKIP_FORGE_VERIFICATION",
"description": f"Skip Forge model verification including model compiling and inference",
"default": "false",
},
]

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Description", "Default"])

@classmethod
def print_configuration_examples(cls, max_width=80):

parameters = [
{"name": "SKIP_FORGE_VERIFICATION", "description": "export SKIP_FORGE_VERIFICATION=true"},
]

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"])

@classmethod
def print_formatted_parameters(cls, parameters, max_width=80, headers=["Parameter", "Description"]):
for param in parameters:
Expand Down
2 changes: 2 additions & 0 deletions forge/test/operators/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .utils import LoggerUtils
from .utils import RateLimiter
from .utils import FrameworkModelType
from .features import TestFeaturesConfiguration
from .plan import InputSource
from .plan import TestVector
from .plan import TestCollection
Expand Down Expand Up @@ -41,6 +42,7 @@
"VerifyUtils",
"LoggerUtils",
"RateLimiter",
"TestFeaturesConfiguration",
"FrameworkModelType",
"InputSource",
"TestVector",
Expand Down
50 changes: 50 additions & 0 deletions forge/test/operators/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Optional, List, Union

from forge import ForgeModule, Module, DepricatedVerifyConfig
from forge.tensor import to_pt_tensors
from forge.op_repo import TensorShape
from forge.verify.compare import compare_with_golden
from forge.verify.verify import verify
Expand Down Expand Up @@ -326,3 +327,52 @@ def verify_module_for_inputs(
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
compiled_model = forge.compile(model, sample_inputs=forge_inputs)
verify(inputs, model, compiled_model, verify_config)


def verify_module_for_inputs_torch(
model: Module,
inputs: List[torch.Tensor],
verify_config: Optional[VerifyConfig] = VerifyConfig(),
):

verify_torch(inputs, model, verify_config)


def verify_torch(
inputs: List[torch.Tensor],
framework_model: torch.nn.Module,
verify_cfg: VerifyConfig = VerifyConfig(),
):
"""
Verify the pytorch model with the given inputs
"""
if not verify_cfg.enabled:
logger.warning("Verification is disabled")
return

# 0th step: input checks

# Check if inputs are of the correct type
if not inputs:
raise ValueError("Input tensors must be provided")
for input_tensor in inputs:
if not isinstance(input_tensor, verify_cfg.supported_tensor_types):
raise TypeError(
f"Input tensor must be of type {verify_cfg.supported_tensor_types}, but got {type(input_tensor)}"
)

if not isinstance(framework_model, verify_cfg.framework_model_types):
raise TypeError(
f"Framework model must be of type {verify_cfg.framework_model_types}, but got {type(framework_model)}"
)

# 1st step: run forward pass for the networks
fw_out = framework_model(*inputs)

# 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary,
# cast from tensorflow tensors to pytorch tensors if needed)
if not isinstance(fw_out, torch.Tensor):
fw_out = to_pt_tensors(fw_out)

fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
return fw_out
17 changes: 17 additions & 0 deletions forge/test/operators/utils/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import os


class TestFeaturesConfiguration:
"""Store test features configuration"""

__test__ = False # Disable pytest collection

@staticmethod
def get_env_property(env_var: str, default_value: str):
return os.getenv(env_var, default_value)

SKIP_FORGE_VERIFICATION = get_env_property("SKIP_FORGE_VERIFICATION", "false").lower() == "true"
17 changes: 16 additions & 1 deletion forge/test/operators/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@
from forge.verify.config import VerifyConfig

from .compat import TestDevice
from .compat import create_torch_inputs, verify_module_for_inputs, verify_module_for_inputs_deprecated
from .compat import (
create_torch_inputs,
verify_module_for_inputs,
verify_module_for_inputs_deprecated,
verify_module_for_inputs_torch,
)
from .datatypes import ValueRanges
from .features import TestFeaturesConfiguration


# All supported framework model types
Expand Down Expand Up @@ -130,6 +136,7 @@ def verify(
warm_reset: bool = False,
deprecated_verification: bool = True,
verify_config: Optional[VerifyConfig] = VerifyConfig(),
skip_forge_verification: bool = TestFeaturesConfiguration.SKIP_FORGE_VERIFICATION,
):
"""Perform Forge verification on the model
Expand All @@ -146,6 +153,8 @@ def verify(
random_seed: Random seed
warm_reset: Warm reset the device before verification
deprecated_verification: Use deprecated verification method
verify_config: Verification configuration
skip_forge_verification: Skip verification with Forge module
"""

cls.setup(
Expand All @@ -168,6 +177,12 @@ def verify(
pcc=pcc,
dev_data_format=dev_data_format,
)
elif skip_forge_verification:
verify_module_for_inputs_torch(
model=model,
inputs=inputs,
verify_config=verify_config,
)
else:
cls.verify_module_for_inputs(
model=model,
Expand Down

0 comments on commit c2469c1

Please sign in to comment.