Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pcc allclose tunable for each test. Add test for MLP training. #189

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions tests/infra/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,21 @@ class AtolConfig(ConfigBase):
required_atol: float = 1.6e-1


@dataclass
class PccConfig(ConfigBase):
required_pcc: float = 0.99


@dataclass
class AllcloseConfig(ConfigBase):
rtol: float = 1e-2
atol: float = 1e-2


# When tensors are too close, pcc will output NaN values.
# Therefore, for each test it should be possible to separately tune the threshold of allclose.rtol and allclose.atol
# below which pcc won't be calculated and therefore test will be able to pass without pcc comparison.
@dataclass
class PccConfig(ConfigBase):
required_pcc: float = 0.99
allclose: AllcloseConfig = AllcloseConfig()
umalesTT marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class ComparisonConfig:
equal: EqualConfig = EqualConfig(False)
Expand Down Expand Up @@ -106,9 +110,7 @@ def compare_pcc(

# If tensors are really close, pcc will be nan. Handle that before calculating pcc.
try:
compare_allclose(
device_output, golden_output, AllcloseConfig(rtol=1e-2, atol=1e-2)
)
compare_allclose(device_output, golden_output, pcc_config.allclose)
except AssertionError:
pcc = jnp.corrcoef(device_output.flatten(), golden_output.flatten())
pcc = jnp.min(pcc)
Expand Down
53 changes: 53 additions & 0 deletions tests/jax/graphs/test_MLP_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import jax
import jax.numpy as jnp
import pytest
from infra import ComparisonConfig, run_graph_test_with_random_inputs


@pytest.fixture
def comparison_config() -> ComparisonConfig:
config = ComparisonConfig()
config.pcc.allclose.atol = 0.03
config.pcc.allclose.rtol = 0.03
return config


@pytest.mark.parametrize(
["W1", "b1", "W2", "b2", "X", "y"],
[
[(784, 64), (32, 64), (64, 10), (32, 10), (32, 784), (32, 10)]
], # 32 samples, 784 features (28x28), 10 output classes
)
def test_nn_with_relu(W1, b1, W2, b2, X, y, comparison_config: ComparisonConfig):
def simple_nn(W1, b1, W2, b2, X, y):
def forward(W1, b1, W2, b2, X):
hidden = jnp.dot(X, W1) + b1
hidden = jnp.maximum(0, hidden)
output = jnp.dot(hidden, W2) + b2
return output

def loss(W1, b1, W2, b2, X, y):
output = forward(W1, b1, W2, b2, X)
return jnp.mean((output - y) ** 2)

@jax.jit
def update_params(W1, b1, W2, b2, X, y, lr=0.01):
grads = jax.grad(loss, argnums=(0, 1, 2, 3))(W1, b1, W2, b2, X, y)
W1 -= lr * grads[0]
b1 -= lr * grads[1]
W2 -= lr * grads[2]
b2 -= lr * grads[3]
return W1, b1, W2, b2, grads

for i in range(50):
W1, b1, W2, b2, grads = update_params(W1, b1, W2, b2, X, y, lr=0.01)

final_loss = loss(W1, b1, W2, b2, X, y)
return final_loss

run_graph_test_with_random_inputs(
simple_nn, [W1, b1, W2, b2, X, y], comparison_config=comparison_config
)
Loading