Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Demo version of ResNet 50 #1133

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions forge/test/mlir/resnet/test_resnet_inference.py

This file was deleted.

145 changes: 0 additions & 145 deletions forge/test/mlir/resnet/test_resnet_unique_ops.py

This file was deleted.

125 changes: 78 additions & 47 deletions forge/test/models/pytorch/vision/resnet/test_resnet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import random

import pytest
import requests
import timm
import torch
from datasets import load_dataset
from loguru import logger
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from tabulate import tabulate
from torchvision.models.resnet import resnet50
from transformers import AutoImageProcessor, ResNetForImageClassification

import forge
from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AutomaticValueChecker
from forge.verify.verify import verify

from test.models.utils import Framework, Source, Task, build_module_name
Expand All @@ -27,7 +28,9 @@
@pytest.mark.nightly
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_resnet_hf(variant, record_forge_property):
# Record model properties
random.seed(0)

# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH,
model="resnet",
Expand All @@ -37,69 +40,97 @@ def test_resnet_hf(variant, record_forge_property):
)
record_forge_property("model_name", module_name)

# Load dataset
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
# Load tiny dataset
dataset = load_dataset("zh-plus/tiny-imagenet")
images = random.sample(dataset["valid"]["image"], 10)

# Load Torch model, preprocess image, and label dictionary
processor = download_model(AutoImageProcessor.from_pretrained, variant)
# Load framework model
framework_model = download_model(ResNetForImageClassification.from_pretrained, variant, return_dict=False)
label_dict = framework_model.config.id2label

inputs = processor(image, return_tensors="pt")
inputs = inputs["pixel_values"]
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, input_sample)

compiled_model = forge.compile(framework_model, inputs)
# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))

cpu_logits = framework_model(inputs)[0]
cpu_pred = label_dict[cpu_logits.argmax(-1).item()]
# Run model on sample data and print results
run_and_print_results(framework_model, compiled_model, images)

tt_logits = compiled_model(inputs)[0]
tt_pred = label_dict[tt_logits.argmax(-1).item()]

assert cpu_pred == tt_pred, f"Inference mismatch: CPU prediction: {cpu_pred}, TT prediction: {tt_pred}"
def run_and_print_results(framework_model, compiled_model, inputs):
"""
Runs inference using both a framework model and a compiled model on a list of input images,
then prints the results in a formatted table.

verify([inputs], framework_model, compiled_model)
Args:
framework_model: The original framework-based model.
compiled_model: The compiled version of the model.
inputs: A list of images to process and classify.
"""
label_dict = framework_model.config.id2label
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

results = []
for i, image in enumerate(inputs):
processed_inputs = processor(image, return_tensors="pt")["pixel_values"]

def generate_model_resnet_imgcls_timm_pytorch(variant):
# Load ResNet50 feature extractor and model from TIMM
model = download_model(timm.create_model, variant, pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
cpu_logits = framework_model(processed_inputs)[0]
cpu_conf, cpu_idx = cpu_logits.softmax(-1).max(-1)
cpu_pred = label_dict.get(cpu_idx.item(), "Unknown")

# Load data sample
try:
url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
except:
logger.warning(
"Failed to download the image file, replacing input with random tensor. Please check if the URL is up to date"
)
image = torch.rand(1, 3, 256, 256)
tt_logits = compiled_model(processed_inputs)[0]
tt_conf, tt_idx = tt_logits.softmax(-1).max(-1)
tt_pred = label_dict.get(tt_idx.item(), "Unknown")

# Data preprocessing
pixel_values = transform(image).unsqueeze(0)
results.append([i + 1, cpu_pred, cpu_conf.item(), tt_pred, tt_conf.item()])

return model, [pixel_values], {}
print(
tabulate(
results,
headers=["Example", "CPU Prediction", "CPU Confidence", "Compiled Prediction", "Compiled Confidence"],
tablefmt="grid",
)
)


@pytest.mark.nightly
def test_resnet_timm(record_forge_property):
pytest.skip("Skipping due to the current CI/CD pipeline limitations")

# Build Module Name
# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH, model="resnet", source=Source.TIMM, variant="50", task=Task.IMAGE_CLASSIFICATION
)
record_forge_property("model_name", module_name)

# Load framework model
framework_model = download_model(timm.create_model, "resnet50", pretrained=True)

# Record Forge Property
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, sample_inputs=input_sample, module_name=module_name)

# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))


@pytest.mark.nightly
def test_resnet_torchvision(record_forge_property):
# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH,
model="resnet",
source=Source.TORCHVISION,
variant="50",
task=Task.IMAGE_CLASSIFICATION,
)
record_forge_property("model_name", module_name)

framework_model, inputs, _ = generate_model_resnet_imgcls_timm_pytorch("resnet50")
# Load framework model
framework_model = resnet50()

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, input_sample)

# Model Verification
verify(inputs, framework_model, compiled_model)
# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))
4 changes: 0 additions & 4 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ testpaths =
forge/test/mlir/llama/test_llama_inference.py::test_llama_inference
forge/test/mlir/llama/tests

# Resnet
forge/test/mlir/resnet/test_resnet_inference.py::test_resnet_inference
forge/test/mlir/resnet/test_resnet_unique_ops.py

# Benchmark
# MNIST Linear
forge/test/benchmark/benchmark/models/mnist_linear.py::test_mnist_linear
Expand Down
Loading