Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Sep 30, 2024
1 parent bb3eb62 commit 3df3e71
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 78 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
target : [ tensorflow ]
target : [ jax, tensorflow ]
steps:
- name: Checkout Ivy 🛎
uses: actions/checkout@v3
Expand Down
79 changes: 2 additions & 77 deletions ivy_tests/test_integrations/test_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,29 +113,6 @@ def test_xyz_to_rgb(target_framework, backend_compile):
)


def test_raw_to_rgb_2x2_downscaled(target_framework, backend_compile):
trace_args = (
torch.rand(1, 1, 4, 6),
kornia.color.CFA.RG,
)
trace_kwargs = {}
test_args = (
torch.rand(5, 1, 4, 6),
kornia.color.CFA.RG,
)
test_kwargs = {}
_test_function(
"kornia.color.raw_to_rgb_2x2_downscaled",
trace_args,
trace_kwargs,
test_args,
test_kwargs,
target_framework,
backend_compile=backend_compile,
tolerance=1e-3,
)


def test_sepia(target_framework, backend_compile):
trace_args = (torch.rand(1, 3, 4, 4),)
trace_kwargs = {
Expand Down Expand Up @@ -1242,13 +1219,13 @@ def test_unproject_meshgrid(target_framework, backend_compile):
4,
torch.eye(3),
)
trace_kwargs = {"normalize_points": False, "device": "cpu", "dtype": torch.float32}
trace_kwargs = {"normalize_points": False, "device": "cpu"}
test_args = (
5,
5,
torch.eye(3),
)
test_kwargs = {"normalize_points": False, "device": "cpu", "dtype": torch.float32}
test_kwargs = {"normalize_points": False, "device": "cpu"}
_test_function(
"kornia.geometry.depth.unproject_meshgrid",
trace_args,
Expand Down Expand Up @@ -1550,58 +1527,6 @@ def test_determinant_to_polynomial(target_framework, backend_compile):
)


def test_spatial_soft_argmax2d(target_framework, backend_compile):
trace_args = (torch.rand(1, 1, 5, 5),)
trace_kwargs = {
"temperature": torch.tensor(1.0),
"normalized_coordinates": True,
}
test_args = (torch.rand(10, 1, 5, 5),)
test_kwargs = {
"temperature": torch.tensor(0.5),
"normalized_coordinates": True,
}
_test_function(
"kornia.geometry.subpix.spatial_soft_argmax2d",
trace_args,
trace_kwargs,
test_args,
test_kwargs,
target_framework,
backend_compile=backend_compile,
tolerance=1e-3,
)


def test_render_gaussian2d(target_framework, backend_compile):
trace_args = (
torch.tensor([[1.0, 1.0]]),
torch.tensor([[1.0, 1.0]]),
(5, 5),
)
trace_kwargs = {
"normalized_coordinates": False,
}
test_args = (
torch.tensor([[2.0, 2.0]]),
torch.tensor([[0.5, 0.5]]),
(10, 10),
)
test_kwargs = {
"normalized_coordinates": False,
}
_test_function(
"kornia.geometry.subpix.render_gaussian2d",
trace_args,
trace_kwargs,
test_args,
test_kwargs,
target_framework,
backend_compile=backend_compile,
tolerance=1e-3,
)


def test_nms3d(target_framework, backend_compile):
trace_args = (
torch.rand(1, 1, 5, 5, 5),
Expand Down

0 comments on commit 3df3e71

Please sign in to comment.