Skip to content

Commit

Permalink
Update ResNet50 E2E test to represent a valid demo version with actua…
Browse files Browse the repository at this point in the history
…l predictions (#1126)

### Ticket

### Problem description
Update ResNet50 E2E test to represent a valid demo version with actual
predictions

### What's changed
- Loading dataset (currently single sample) & labels dictionary
- Comparing CPU vs TT label predictions

### Checklist
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
nvukobratTT authored Jan 29, 2025
1 parent 4a19bbe commit d407988
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions forge/test/models/pytorch/vision/resnet/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,65 +5,61 @@
import requests
import timm
import torch
from datasets import load_dataset
from loguru import logger
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers import AutoFeatureExtractor, ResNetForImageClassification
from transformers import AutoImageProcessor, ResNetForImageClassification

import forge
from forge.verify.verify import verify

from test.models.utils import Framework, Source, Task, build_module_name
from test.utils import download_model


def generate_model_resnet_imgcls_hf_pytorch(variant):
# Load ResNet feature extractor and model checkpoint from HuggingFace
model_ckpt = variant
feature_extractor = download_model(AutoFeatureExtractor.from_pretrained, model_ckpt)
model = download_model(ResNetForImageClassification.from_pretrained, model_ckpt)

# Load data sample
try:
url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg"
image = Image.open(requests.get(url, stream=True).raw)
except:
logger.warning(
"Failed to download the image file, replacing input with random tensor. Please check if the URL is up to date"
)
image = torch.rand(1, 3, 256, 256)

# Data preprocessing
inputs = feature_extractor(image, return_tensors="pt")
pixel_values = inputs["pixel_values"]

return model, [pixel_values], {}
variants = [
"microsoft/resnet-50",
]


@pytest.mark.push
@pytest.mark.nightly
def test_resnet(record_forge_property):
# Build Module Name
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_resnet_hf(variant, record_forge_property):
# Record model properties
module_name = build_module_name(
framework=Framework.PYTORCH,
model="resnet",
variant="50",
source=Source.HUGGINGFACE,
task=Task.IMAGE_CLASSIFICATION,
)

# Record Forge Property
record_forge_property("model_name", module_name)

framework_model, inputs, _ = generate_model_resnet_imgcls_hf_pytorch(
"microsoft/resnet-50",
)
# Load dataset
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)
# Load Torch model, preprocess image, and label dictionary
processor = download_model(AutoImageProcessor.from_pretrained, variant)
framework_model = download_model(ResNetForImageClassification.from_pretrained, variant, return_dict=False)
label_dict = framework_model.config.id2label

# Model Verification
verify(inputs, framework_model, compiled_model)
inputs = processor(image, return_tensors="pt")
inputs = inputs["pixel_values"]

compiled_model = forge.compile(framework_model, inputs)

cpu_logits = framework_model(inputs)[0]
cpu_pred = label_dict[cpu_logits.argmax(-1).item()]

tt_logits = compiled_model(inputs)[0]
tt_pred = label_dict[tt_logits.argmax(-1).item()]

assert cpu_pred == tt_pred, f"Inference mismatch: CPU prediction: {cpu_pred}, TT prediction: {tt_pred}"

verify([inputs], framework_model, compiled_model)


def generate_model_resnet_imgcls_timm_pytorch(variant):
Expand Down

0 comments on commit d407988

Please sign in to comment.