Skip to content

Commit

Permalink
Add some more validation checks for torch.linalg.eigh and torch.compi…
Browse files Browse the repository at this point in the history
…le (#1580)

* Add some more validation checks for torch.linalg.eigh and torch.compile

* Update test

* Also update smoke_test.py

* Fix lint
  • Loading branch information
huydhn authored Nov 16, 2023
1 parent c6cbe77 commit 4c7fa06
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
6 changes: 6 additions & 0 deletions check_binary.sh
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE
echo "Test that linalg works"
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))"

echo "Test that linalg.eigh works"
python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(torch.mm(x.t(), x)))"

echo "Checking that basic torch.compile works"
python ${TEST_CODE_DIR}/torch_compile_smoke.py

popd
fi # if libtorch
fi # if cuda
Expand Down
3 changes: 3 additions & 0 deletions test/smoke_test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def smoke_test_linalg() -> None:
A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype)
torch.linalg.svd(A)

A = torch.rand(3, 3, device="cuda")
L, Q = torch.linalg.eigh(torch.mm(A.t(), A))


def smoke_test_compile() -> None:
supported_dtypes = [torch.float16, torch.float32, torch.float64]
Expand Down
12 changes: 12 additions & 0 deletions test_example_code/torch_compile_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch


def foo(x: torch.Tensor) -> torch.Tensor:
return torch.sin(x) + torch.cos(x)


if __name__ == "__main__":
x = torch.rand(3, 3, device="cuda")
x_eager = foo(x)
x_pt2 = torch.compile(foo)(x)
print(torch.allclose(x_eager, x_pt2))

0 comments on commit 4c7fa06

Please sign in to comment.