Skip to content

Commit

Permalink
Add support for generating nightly test in model analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
chandrasekaranpradeep committed Dec 23, 2024
1 parent a236b80 commit 0f3ca41
Show file tree
Hide file tree
Showing 6 changed files with 485 additions and 46 deletions.
7 changes: 5 additions & 2 deletions forge/forge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def add_parameter(self, name: str, parameter: Parameter, prepend_name: bool = Fa
else:
parameter._set_auto_name(name)

def add_constant(self, name: str, prepend_name: bool = False, shape: Tuple[int] = None):
def add_constant(self, name: str, prepend_name: bool = False, shape: Tuple[int] = None, dtype: torch.dtype = torch.float32, use_random_value: bool = False):
"""
Adds a new constant.
Expand All @@ -741,7 +741,10 @@ def add_constant(self, name: str, prepend_name: bool = False, shape: Tuple[int]
_name = name

if shape:
self._constants[_name] = Tensor.create_from_torch(torch.empty(shape), constant=True)
if use_random_value:
self._constants[_name] = Tensor.create_from_torch(Tensor.create_torch_tensor(shape, dtype), constant=True)
else:
self._constants[_name] = Tensor.create_from_torch(torch.empty(shape), constant=True)
else:
self._constants[_name] = None

Expand Down
4 changes: 4 additions & 0 deletions forge/forge/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
requires_grad: bool = True,
name: str = None,
dev_data_format: Optional[DataFormat] = None,
use_random_value: bool = False,
):
"""
Create parameter of given shape.
Expand Down Expand Up @@ -73,6 +74,9 @@ def __init__(
self._data_format = pytorch_dtype_to_forge_dataformat(self._value.dtype)
else:
self._data_format = DataFormat.Float32 # default

if self._value is None and use_random_value:
self._value = Tensor.create_torch_tensor(shape=self._tensor_shape.get_pytorch_shape(), dtype=forge_dataformat_to_pytorch_dtype(self._data_format))

def __repr__(self):
ret = f"Forge Parameter {self.get_name()} {self.shape}"
Expand Down
59 changes: 45 additions & 14 deletions forge/forge/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import forge
from forge.tensor import forge_dataformat_to_pytorch_dtype

from typing import Tuple, List
from typing import Tuple, List, Optional, Dict


def forge_df_from_str(df: str, name: str, return_as_str: bool = True):
Expand Down Expand Up @@ -172,7 +172,7 @@ def write_header(self, include_pytest_imports=False):

self.wl("\n")

def write_class_definition(self, params, constants, class_name=None, num_submodels=0, is_submodel=False):
def write_class_definition(self, params, constants, class_name=None, num_submodels=0, is_submodel=False, use_random_value=False):
if class_name is None:
class_name = self.class_name
self.num_submodels = num_submodels
Expand Down Expand Up @@ -200,18 +200,28 @@ def write_class_definition(self, params, constants, class_name=None, num_submode
f'self.add_parameter("{name}", forge.Parameter(*{shape}, requires_grad={requires_grad}, dev_data_format={forge_df_from_str(dtype, name)}), prepend_name=True)'
)
else:
self.wl(
f'self.add_parameter("{name}", forge.Parameter(*{shape}, requires_grad={requires_grad}, dev_data_format={forge_df_from_str(dtype, name)}))'
)
if use_random_value:
self.wl(
f'self.add_parameter("{name}", forge.Parameter(*{shape}, requires_grad={requires_grad}, dev_data_format={forge_df_from_str(dtype, name)}, use_random_value=True))'
)
else:
self.wl(
f'self.add_parameter("{name}", forge.Parameter(*{shape}, requires_grad={requires_grad}, dev_data_format={forge_df_from_str(dtype, name)}))'
)

for const in constants.values():
name = const[0]
shape = tuple(const[1])
dtype = forge_dataformat_to_pytorch_dtype(forge_df_from_str(const[2], name, False))
self.const_names.append(name)
if is_submodel:
self.wl(f'self.add_constant("{name}", prepend_name=True, shape={shape})')
else:
self.wl(f'self.add_constant("{name}", shape={shape})')
if use_random_value:
self.wl(f'self.add_constant("{name}", shape={shape}, dtype={dtype}, use_random_value=True)')
else:
self.wl(f'self.add_constant("{name}", shape={shape})')


self.indent = 0
self.wl("")
Expand Down Expand Up @@ -1024,8 +1034,10 @@ def write_model_parameter_function(self, param_file_name, named_params_file_name
def write_pytest_function(
self,
forge_module_names: List[str],
framework: str,
pytest_input_shapes_and_dtypes_list: List[List[Tuple]],
markers: Optional[List[str]] = [],
module_metadata: Optional[Dict[str, str]] = {},
pytest_metadata_list: Optional[List[Dict[str, str]]] = [],
):
"""
Generates a pytest function that tests modules with input shapes and data types.
Expand All @@ -1041,28 +1053,47 @@ def write_pytest_function(
Args:
forge_module_names (List[str]): List of names of the modules to be tested, each corresponding to a forge module.
framework (str): The name of the framework under which the model is to be tested (e.g., "pytorch").
pytest_input_shapes_and_dtypes_list (List[List[Tuple]]): A list of input shapes and corresponding data types for each module. Each tuple contains the shape and dtype to be tested.
"""
self.wl("")
self.wl("")
self.wl("forge_modules_and_shapes_dtypes_list = [")
self.indent += 1
for forge_module_name, pytest_input_shapes_and_dtypes in zip(
forge_module_names, pytest_input_shapes_and_dtypes_list
is_pytest_metadata_list_empty = False
if len(pytest_metadata_list) == 0:
pytest_metadata_list = [{}] * len(pytest_input_shapes_and_dtypes_list)
is_pytest_metadata_list_empty = True
for forge_module_name, pytest_input_shapes_and_dtypes, pytest_metadata in zip(
forge_module_names, pytest_input_shapes_and_dtypes_list, pytest_metadata_list
):
pytest_input_shapes_and_dtypes = [
(shape, forge_dataformat_to_pytorch_dtype(forge_df_from_str(dtype, "", False)))
for shape, dtype in pytest_input_shapes_and_dtypes
]
self.wl(f"({forge_module_name}, {pytest_input_shapes_and_dtypes}), ")
if len(pytest_metadata) == 0:
self.wl(f"({forge_module_name}, {pytest_input_shapes_and_dtypes}), ")
else:
self.wl(f"({forge_module_name}, {pytest_input_shapes_and_dtypes}, {pytest_metadata}), ")
self.indent -= 1
self.wl("]")
for marker in markers:
self.wl(f"@pytest.mark.{marker}")
self.wl('@pytest.mark.parametrize("forge_module_and_shapes_dtypes", forge_modules_and_shapes_dtypes_list)')
self.wl("def test_module(forge_module_and_shapes_dtypes):")
self.wl("def test_module(forge_module_and_shapes_dtypes, record_property):")
self.indent += 1
if len(module_metadata) != 0:
for metadata_name, metadata_value in module_metadata.items():
self.wl(f'record_property("{metadata_name}", "{metadata_value}")')
self.wl("")
self.wl("forge_module, operand_shapes_dtypes = forge_module_and_shapes_dtypes")
if is_pytest_metadata_list_empty:
self.wl("forge_module, operand_shapes_dtypes = forge_module_and_shapes_dtypes")
else:
self.wl("forge_module, operand_shapes_dtypes, metadata = forge_module_and_shapes_dtypes")
self.wl("")
self.wl("for metadata_name, metadata_value in metadata.items():")
self.indent += 1
self.wl(f'record_property(metadata_name, metadata_value)')
self.indent -= 1
self.wl("")
need_model_parameter_function = any(
[
Expand All @@ -1085,7 +1116,7 @@ def write_pytest_function(
self.wl("")
self.wl("compiled_model = compile(framework_model, sample_inputs=inputs)")
self.wl("")
self.wl("verify(inputs, framework_model, compiled_model, VerifyConfig(verify_allclose=False))")
self.wl("verify(inputs, framework_model, compiled_model)")
self.wl("")
self.wl("")
self.indent -= 1
Expand Down
24 changes: 18 additions & 6 deletions forge/forge/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,23 @@ def create_from_torch(
"""
return TensorFromPytorch(torch_tensor, dev_data_format, constant)

@classmethod
def create_torch_tensor(
cls,
shape: Union[List, Tuple, torch.Size],
dtype: Optional[torch.dtype] = None,
integer_high_value: int = 1000,
) -> torch.Tensor:

if dtype in [torch.float16, torch.bfloat16, torch.float32]:
torch_tensor = torch.rand(shape, dtype=dtype)
elif dtype in [torch.int8, torch.int, torch.int32]:
torch_tensor = torch.randint(high=integer_high_value, size=shape, dtype=dtype)
else:
torch_tensor = torch.rand(shape, dtype=torch.float32)

return torch_tensor

@classmethod
def create_from_shape(
cls,
Expand All @@ -283,12 +300,7 @@ def create_from_shape(
constant: bool = False,
) -> "TensorFromPytorch":

if torch_dtype in [torch.float16, torch.bfloat16, torch.float32]:
torch_tensor = torch.rand(tensor_shape, dtype=torch_dtype)
elif torch_dtype in [torch.int8, torch.int, torch.int32]:
torch_tensor = torch.randint(high=integer_tensor_high_value, size=tensor_shape, dtype=torch_dtype)
else:
torch_tensor = torch.rand(tensor_shape, dtype=torch.float32)
torch_tensor = Tensor.create_torch_tensor(shape=tensor_shape, dtype=torch_dtype, integer_high_value = 1000)

return TensorFromPytorch(
torch_tensor, dev_data_format=pytorch_dtype_to_forge_dataformat(torch_dtype), constant=constant
Expand Down
Loading

0 comments on commit 0f3ca41

Please sign in to comment.