Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add abstract impl fbgemm::dense_to_jagged #2193

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,31 @@ def segment_sum_csr_abstract(
output_size = csr_seg.numel() - 1
output = values.new_empty(output_size)
return output


@impl_abstract("fbgemm::dense_to_jagged_forward")
def dense_to_jagged_forward(
dense: torch.Tensor,
offsets: List[torch.Tensor],
total_L: Optional[torch.SymInt] = None,
) -> torch.Tensor:
if not total_L:
total_L = torch.library.get_ctx().new_dynamic_size()
return dense.new_zeros(
total_L,
dense.size()[-1],
dtype=dense.dtype,
device=dense.device,
layout=dense.layout,
)


@impl_abstract("fbgemm::dense_to_jagged")
def dense_to_jagged(
dense: torch.Tensor,
offsets: List[torch.Tensor],
total_L: Optional[torch.SymInt] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
if not total_L:
total_L = torch.library.get_ctx().new_dynamic_size()
return (dense_to_jagged_forward(dense, offsets, total_L), offsets)
3 changes: 3 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
// SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more
// details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If
// you find it doesn't compile, please pull the new PyTorch 2.0 code
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
m.def(
"dense_to_jagged(Tensor dense, Tensor[] x_offsets, SymInt? total_L=None) -> (Tensor, Tensor[])",
{PT2_COMPLIANT_TAG});
Expand Down
26 changes: 0 additions & 26 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,6 @@ Tensor jagged_dense_elementwise_add_meta(
return at::empty_like(y);
}

Tensor dense_to_jagged_forward_meta(
const Tensor& dense,
const std::vector<Tensor>& offsets,
c10::optional<at::SymInt> total_L) {
auto dense_values = dense;
at::SymInt D = dense_values.sym_size(-1);
TORCH_CHECK_NOT_IMPLEMENTED(
total_L.has_value(), "total_L is required for meta backend");
auto& total_L_computed = total_L.value();
auto values = at::zeros_symint({total_L_computed, D}, dense_values.options());

TORCH_CHECK(values.is_meta());
return values;
}

std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged_meta(
const Tensor& dense,
const std::vector<Tensor>& offsets,
c10::optional<at::SymInt> total_L) {
return {dense_to_jagged_forward_meta(dense, offsets, total_L), offsets};
}

std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_meta(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
Expand Down Expand Up @@ -241,10 +219,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl(
"jagged_to_padded_dense_backward",
TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_backward_meta));
m.impl(
"dense_to_jagged_forward",
TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_meta));
m.impl("dense_to_jagged", TORCH_FN(fbgemm_gpu::dense_to_jagged_meta));
m.impl(
"jagged_dense_dense_elementwise_add_jagged_output_forward",
TORCH_FN(
Expand Down
35 changes: 1 addition & 34 deletions fbgemm_gpu/test/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -115,40 +115,7 @@
"status": "xfail"
}
},
"fbgemm::dense_to_jagged": {
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_meta_backend": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_aot_dispatch_dynamic__test_dense_to_jagged_opt_large_batch": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_meta_backend": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt": {
"comment": "",
"status": "xfail"
},
"JaggedTensorOpsTest.test_faketensor__test_dense_to_jagged_opt_large_batch": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::dense_to_jagged": {},
"fbgemm::expand_into_jagged_permute": {},
"fbgemm::generic_histogram_binning_calibration_by_feature": {
"SparseOpsTest.test_aot_dispatch_dynamic__test_generic_histogram_binning_calibration_by_feature": {
Expand Down
4 changes: 0 additions & 4 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,6 @@ def hash_size_cumsum_to_offsets(hash_size_cum_sum_list: List[int]) -> List[int]:
# skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json
# pyre-ignore[24]: Generic type `Callable` expects 2 type parameters.
additional_decorators: Dict[str, List[Callable]] = {
"test_pt2_compliant_tag_fbgemm_dense_to_jagged": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
],
"test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
Expand Down
Loading