From b4680645958636baee59c0b7878465ae96a930d9 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 7 Dec 2023 08:43:09 -0800 Subject: [PATCH] Add abstract impl fbgemm::dense_to_jagged (#2193) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2193 1. Remove the meta cpp function fbgemm::dense_to_jagged() and fbgemm::dense_to_jagged_forward() 2. Replace it with the Python abstract impl dense_to_jagged() Reviewed By: zou3519, yanboliang Differential Revision: D51216256 fbshipit-source-id: db6cc16d2eb91cff21ab1c4eb512af69338c1826 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 28 +++++++++++++++ .../jagged_tensor_ops_cpu.cpp | 3 ++ .../jagged_tensor_ops_meta.cpp | 26 -------------- fbgemm_gpu/test/failures_dict.json | 35 +------------------ fbgemm_gpu/test/jagged_tensor_ops_test.py | 4 --- 5 files changed, 32 insertions(+), 64 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 0979959a53..9e1bb468fb 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -371,3 +371,31 @@ def segment_sum_csr_abstract( output_size = csr_seg.numel() - 1 output = values.new_empty(output_size) return output + + +@impl_abstract("fbgemm::dense_to_jagged_forward") +def dense_to_jagged_forward( + dense: torch.Tensor, + offsets: List[torch.Tensor], + total_L: Optional[torch.SymInt] = None, +) -> torch.Tensor: + if not total_L: + total_L = torch.library.get_ctx().new_dynamic_size() + return dense.new_zeros( + total_L, + dense.size()[-1], + dtype=dense.dtype, + device=dense.device, + layout=dense.layout, + ) + + +@impl_abstract("fbgemm::dense_to_jagged") +def dense_to_jagged( + dense: torch.Tensor, + offsets: List[torch.Tensor], + total_L: Optional[torch.SymInt] = None, +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if not total_L: + total_L = torch.library.get_ctx().new_dynamic_size() + return (dense_to_jagged_forward(dense, offsets, total_L), offsets) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index 5a1753b239..fb5ba53798 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -1635,6 +1635,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { // SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more // details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If // you find it doesn't compile, please pull the new PyTorch 2.0 code + m.impl_abstract_pystub( + "fbgemm_gpu.sparse_ops", + "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); m.def( "dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])", {PT2_COMPLIANT_TAG}); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp index b9e249cb90..fabcd6455b 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp @@ -92,28 +92,6 @@ Tensor jagged_dense_elementwise_add_meta( return at::empty_like(y); } -Tensor dense_to_jagged_forward_meta( - const Tensor& dense, - const std::vector& offsets, - c10::optional total_L) { - auto dense_values = dense; - at::SymInt D = dense_values.sym_size(-1); - TORCH_CHECK_NOT_IMPLEMENTED( - total_L.has_value(), "total_L is required for meta backend"); - auto& total_L_computed = total_L.value(); - auto values = at::zeros_symint({total_L_computed, D}, dense_values.options()); - - TORCH_CHECK(values.is_meta()); - return values; -} - -std::tuple> dense_to_jagged_meta( - const Tensor& dense, - const std::vector& offsets, - c10::optional total_L) { - return {dense_to_jagged_forward_meta(dense, offsets, total_L), offsets}; -} - std::tuple> jagged_dense_elementwise_mul_meta( const Tensor& x_values, const std::vector& x_offsets, @@ -241,10 +219,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl( "jagged_to_padded_dense_backward", TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_backward_meta)); - m.impl( - "dense_to_jagged_forward", - TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_meta)); - m.impl("dense_to_jagged", TORCH_FN(fbgemm_gpu::dense_to_jagged_meta)); m.impl( "jagged_dense_dense_elementwise_add_jagged_output_forward", TORCH_FN( diff --git a/fbgemm_gpu/test/failures_dict.json b/fbgemm_gpu/test/failures_dict.json index 43efa968c9..c754ee4c7d 100644 --- a/fbgemm_gpu/test/failures_dict.json +++ b/fbgemm_gpu/test/failures_dict.json @@ -115,40 +115,7 @@ "status": "xfail" } }, - "fbgemm::dense_to_jagged": { - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_meta_backend": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt_large_batch": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_meta_backend": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt": { - "comment": "", - "status": "xfail" - }, - "JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt_large_batch": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::dense_to_jagged": {}, "fbgemm::expand_into_jagged_permute": {}, "fbgemm::generic_histogram_binning_calibration_by_feature": { "SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": { diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py index 8465490282..1ff06206ff 100644 --- a/fbgemm_gpu/test/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py @@ -132,10 +132,6 @@ def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]: # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. additional_decorators: Dict[str, List[Callable]] = { - "test_pt2_compliant_tag_fbgemm_dense_to_jagged": [ - # This operator has been grandfathered in. We need to fix this test failure. - unittest.expectedFailure, - ], "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ # This operator has been grandfathered in. We need to fix this test failure. unittest.expectedFailure,