Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten.searchsorted.Tensor meta kernel #101637

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ def run_meta_crossref(
torch.nn.functional.one_hot : {i64},
torch.nn.functional.pdist : {f64, f32},
torch._segment_reduce : {f64, f16, bf16, f32},
torch.searchsorted : {f64, i32, i64, f16, u8, i16, bf16, i8, f32},
torch.cholesky : {f64, f32, c128, c64},
torch.cholesky_inverse : {f64, f32, c128, c64},
torch.cholesky_solve : {f64, f32, c128, c64},
Expand Down Expand Up @@ -861,8 +860,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.multinomial.out : {bf16, f32, f64},
aten.nll_loss2d_forward.default : {bf16, f32, f64},
aten.rrelu_with_noise.default : {bf16, f32, f64},
aten.searchsorted.Tensor : {f16, i8, f64, i64, bf16, f32, i32, i16, u8},
aten.searchsorted.Tensor_out : {f16, i8, f64, i64, bf16, f32, i32, i16, u8},
aten.segment_reduce.default : {bf16, f32, f16, f64},
aten.unique_consecutive.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
aten.unique_dim.default : {i8, f64, i64, f16, bf16, f32, i32, b8, i16, u8},
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,6 @@ def f(a, b, c, d, e):
xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('roll', ''), # Tensors of type TensorImpl do not have numel
xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
Expand Down
9 changes: 9 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3545,6 +3545,15 @@ def t_(self):
return transpose_(self, 0, 0 if ndims < 2 else 1)


@register_meta([aten.searchsorted.Tensor, aten.searchsorted.Tensor_out])
@out_wrapper()
def meta_searchsorted(
sorted_sequence, self, *, out_int32=False, right=False, side=None, sorter=None
):
dtype = torch.int32 if out_int32 else torch.int64
return torch.empty_like(self, dtype=dtype).contiguous()


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
Expand Down