add meta func for jagged 1d/2d to padded
Summary: As titled; also adds torch.compile unit tests to prevent failures in PT2.

Differential Revision: D48427195

fbshipit-source-id: eaac0904497d3016ac6dc3eaaa1e877332d14a5d
xw285cornell authored and facebook-github-bot committed Aug 19, 2023
1 parent a13b7bb · commit e5fa909
Showing 3 changed files with 175 additions and 54 deletions.
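
For context on the PT2 motivation, here is a minimal, illustrative sketch (not part of this commit) of what the new Meta-backend registrations enable: tracing the jagged ops through torch.compile with dynamic shapes, which is what the unit tests added below exercise. It assumes fbgemm_gpu is installed so the torch.ops.fbgemm operators are registered; the lengths, embedding dimension, and max_sequence_length are arbitrary.

import torch
import fbgemm_gpu  # noqa: F401  # loads the torch.ops.fbgemm operators

# Jagged input: 3 rows with lengths [2, 0, 3] and embedding dim 4.
lengths = torch.tensor([2, 0, 3])
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # tensor([0, 2, 2, 5])
values = torch.rand(int(lengths.sum()), 4)

# With a Meta kernel registered for the op, torch.compile can infer the
# padded output shape symbolically under dynamic shapes (PT2).
compiled_op = torch.compile(
    torch.ops.fbgemm.jagged_2d_to_dense, dynamic=True, fullgraph=True
)
dense = compiled_op(values=values, offsets=offsets, max_sequence_length=6)
print(dense.shape)  # torch.Size([3, 6, 4])

Without a Meta implementation for the op, this compile path would typically graph-break or error when PT2 traces it with fake/meta tensors.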
@@ -697,7 +697,7 @@ class JaggedSliceOp : public torch::autograd::Function<JaggedSliceOp> {
 Tensor jagged_to_padded_dense(
     const Tensor& values,
     const std::vector<Tensor>& offsets,
-    c10::SymIntArrayRef max_lengths,
+    const c10::SymIntArrayRef max_lengths,
     const double padding_value) {
   return JaggedToPaddedDenseOp::apply(
       values, offsets, max_lengths, padding_value)[0];
fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp (41 additions, 1 deletion)
@@ -41,6 +41,15 @@ Tensor jagged_to_padded_dense_forward_meta(
   return at::empty_symint(padded_values_shape, values.options());
 }
 
+Tensor jagged_to_padded_dense_meta(
+    const Tensor& values,
+    const std::vector<Tensor>& offsets,
+    const c10::SymIntArrayRef max_lengths,
+    const double padding_value = 0) {
+  return jagged_to_padded_dense_forward_meta(
+      values, offsets, max_lengths, padding_value);
+}
+
 Tensor jagged_to_padded_dense_backward_meta(
     const at::Tensor& grad_output,
     const std::vector<Tensor>& offsets,
@@ -97,6 +106,13 @@ Tensor dense_to_jagged_forward_meta(
   return values;
 }
 
+std::tuple<Tensor, std::vector<Tensor>> dense_to_jagged_meta(
+    const Tensor& dense,
+    const std::vector<Tensor>& offsets,
+    const c10::optional<at::SymInt>& total_L) {
+  return {dense_to_jagged_forward_meta(dense, offsets, total_L), offsets};
+}
+
 std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul_meta(
     const Tensor& x_values,
     const std::vector<Tensor>& x_offsets,
@@ -193,18 +209,41 @@ Tensor jagged_softmax_backward_meta(
   return at::empty_like(grad_output);
 }
 
+Tensor jagged_1d_to_dense_meta(
+    Tensor values,
+    Tensor offsets,
+    c10::SymInt max_L,
+    int64_t padding_value) {
+  return jagged_to_padded_dense_meta(values, {offsets}, {max_L}, padding_value);
+}
+
+Tensor jagged_2d_to_dense_meta(
+    Tensor values,
+    Tensor offsets,
+    c10::SymInt max_sequence_length) {
+  return jagged_to_padded_dense_meta(
+      values,
+      {offsets},
+      {max_sequence_length},
+      /*padding_value=*/0);
+}
+
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl(
       "jagged_to_padded_dense_forward",
       TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_meta));
+  m.impl(
+      "jagged_to_padded_dense",
+      TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_meta));
   m.impl(
       "jagged_to_padded_dense_backward",
       TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_backward_meta));
   m.impl(
       "dense_to_jagged_forward",
       TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_meta));
+  m.impl("dense_to_jagged", TORCH_FN(fbgemm_gpu::dense_to_jagged_meta));
   m.impl(
       "jagged_dense_dense_elementwise_add_jagged_output_forward",
       TORCH_FN(
@@ -246,5 +285,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl(
       "jagged_jagged_bmm",
       TORCH_FN(fbgemm_gpu::jagged_jagged_bmm_forward_meta));
-  m.impl("jagged_1d_to_dense", TORCH_FN(fbgemm_gpu::jagged_1d_to_dense));
+  m.impl("jagged_1d_to_dense", TORCH_FN(fbgemm_gpu::jagged_1d_to_dense_meta));
+  m.impl("jagged_2d_to_dense", TORCH_FN(fbgemm_gpu::jagged_2d_to_dense_meta));
 }
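
As a rough illustration (again not part of the diff, and not verified against this exact build) of what a Meta implementation provides: the op can be evaluated on meta-device tensors, which carry only dtypes and shapes, so the kernels above only have to compute the padded output shape. B, max_L, D, and total below are made-up sizes.

import torch
import fbgemm_gpu  # noqa: F401  # loads the operators and their Meta kernels

B, max_L, D = 8, 16, 32
total = 50  # arbitrary number of jagged elements

# Meta tensors allocate no storage; only dtype and shape are tracked.
values = torch.empty(total, D, device="meta")
offsets = torch.empty(B + 1, dtype=torch.long, device="meta")

out = torch.ops.fbgemm.jagged_2d_to_dense(values, offsets, max_L)
print(out.shape)  # torch.Size([8, 16, 32]), still on the meta device

This shape-only path is what torch.compile's fake-tensor tracing relies on, which is presumably why the missing registrations showed up as failures in PT2.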
fbgemm_gpu/test/jagged_tensor_ops_test.py (133 additions, 52 deletions)
@@ -47,6 +47,31 @@ def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
     )
 
 
+# Converts lengths + values format to COO format
+# [B], [N] -> [B, N'].
+# pyre-ignore Missing return annotation [3]
+def var_list_to_coo_1d(
+    lengths: torch.Tensor,
+    values: torch.Tensor,
+    N: int,
+):
+    rows = lengths_to_segment_ids(lengths)
+    num_rows = lengths.size()[0]
+    # This does D&H sync
+    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+    output_size = lengths.sum()
+    # This does D&H sync
+    cols = torch.ops.fbgemm.offsets_range(offsets, output_size)
+    indices = torch.stack([rows, cols])
+    dims = [num_rows, N]
+    # torch.sparse_coo_tensor is not supported by torch.fx, wrap it.
+    return torch.sparse_coo_tensor(
+        indices=indices,
+        values=values,
+        size=dims,
+    )
+
+
 # Converts lengths + values format to COO format
 # [B], [N, D] -> [B, N', D].
 # pyre-ignore Missing return annotation [3]
@@ -279,6 +304,60 @@ def test_jagged_2d_to_dense_truncation(self) -> None:
         output_values.backward(ref_output_values)
         torch.testing.assert_close(expected_grad, values.grad)
 
+    @settings(
+        verbosity=Verbosity.verbose,
+        max_examples=20,
+        deadline=None,
+    )
+    @given(
+        B=st.integers(min_value=2, max_value=128),
+        D=st.integers(min_value=2, max_value=128),
+        max_sequence_length=st.integers(min_value=1, max_value=200),
+        dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
+        device_type=st.sampled_from(["cpu", "cuda"])
+        if gpu_available
+        else st.just("cpu"),
+    )
+    def test_jagged_2d_to_dense_dynamic_shape(
+        self,
+        B: int,
+        D: int,
+        max_sequence_length: int,
+        dtype: torch.dtype,
+        device_type: str,
+    ) -> None:
+        D = D * 4
+        lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B)
+        total_lengths = lengths_.sum()
+        lengths = torch.from_numpy(lengths_)
+        offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+
+        ref_values = torch.rand(total_lengths, D)
+        ref_output_values = var_list_to_coo(
+            lengths,
+            ref_values,
+            max_sequence_length,
+            D,
+        ).to_dense()
+        ref_output_values = ref_output_values.to(dtype)
+
+        ref_values = ref_values.to(device_type)
+        values = ref_values.clone().to(dtype).detach().requires_grad_(True)
+        offsets = offsets.to(device_type)
+        ref_output_values = ref_output_values.to(device_type)
+        output_values = torch.compile(
+            torch.ops.fbgemm.jagged_2d_to_dense, dynamic=True, fullgraph=True
+        )(
+            values=values,
+            offsets=offsets,
+            max_sequence_length=max_sequence_length,
+        )
+        torch.testing.assert_close(ref_output_values, output_values)
+
+        output_values.backward(ref_output_values)
+        ref_values = ref_values.to(dtype)
+        torch.testing.assert_close(ref_values, values.grad)
+
     @unittest.skipIf(*gpu_unavailable)
     @settings(
         verbosity=Verbosity.verbose,
@@ -359,41 +438,17 @@ def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
                 lengths.long(),
             )
 
-        # Converts lengths + values format to COO format
-        # [B], [N] -> [B, N'].
-        # pyre-ignore Missing return annotation [3]
-        def var_list_to_coo(
-            lengths: torch.Tensor,
-            values: torch.Tensor,
-            N: int,
-        ):
-            rows = lengths_to_segment_ids(lengths)
-            num_rows = lengths.size()[0]
-            # This does D&H sync
-            offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            output_size = lengths.sum()
-            # This does D&H sync
-            cols = torch.ops.fbgemm.offsets_range(offsets, output_size)
-            indices = torch.stack([rows, cols])
-            dims = [num_rows, N]
-            # torch.sparse_coo_tensor is not supported by torch.fx, wrap it.
-            return torch.sparse_coo_tensor(
-                indices=indices,
-                values=values,
-                size=dims,
-            )
-
         lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B)
         total_lengths = lengths_.sum()
         lengths = torch.from_numpy(lengths_)
         offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
 
         ref_values = torch.randint(low=0, high=1000000000, size=(total_lengths,))
-        ref_values_mask = var_list_to_coo(
+        ref_values_mask = var_list_to_coo_1d(
            lengths, torch.ones_like(ref_values), max_sequence_length
         ).to_dense()
         ref_output_values = (
-            var_list_to_coo(
+            var_list_to_coo_1d(
                 lengths,
                 ref_values,
                 max_sequence_length,
@@ -457,6 +512,58 @@ def test_jagged_1d_to_dense_truncation(self) -> None:
         )
         torch.testing.assert_close(ref_output, output)
 
+    @settings(
+        verbosity=Verbosity.verbose,
+        max_examples=20,
+        deadline=None,
+    )
+    @given(
+        B=st.integers(min_value=1, max_value=128),
+        max_sequence_length=st.integers(min_value=1, max_value=500),
+        padding_value=st.integers(min_value=-100000, max_value=100000),
+        device_type=st.sampled_from(["cpu", "cuda"])
+        if gpu_available
+        else st.just("cpu"),
+    )
+    def test_jagged_1d_to_dense_dynamic_shape(
+        self, B: int, max_sequence_length: int, padding_value: int, device_type: str
+    ) -> None:
+        def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
+            return torch.repeat_interleave(
+                torch._dim_arange(lengths, 0).long(),
+                lengths.long(),
+            )
+
+        lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B)
+        total_lengths = lengths_.sum()
+        lengths = torch.from_numpy(lengths_)
+        offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+
+        ref_values = torch.randint(low=0, high=1000000000, size=(total_lengths,))
+        ref_values_mask = var_list_to_coo_1d(
+            lengths, torch.ones_like(ref_values), max_sequence_length
+        ).to_dense()
+        ref_output_values = (
+            var_list_to_coo_1d(
+                lengths,
+                ref_values,
+                max_sequence_length,
+            ).to_dense()
+            + (1 - ref_values_mask) * torch.ones_like(ref_values_mask) * padding_value
+        )
+
+        ref_values = ref_values.to(device_type)
+        values = ref_values.clone().detach().requires_grad_(False)
+        offsets = offsets.to(device_type)
+        ref_output_values = ref_output_values.to(device_type)
+        output_values = torch.compile(torch.ops.fbgemm.jagged_1d_to_dense, dynamic=True, fullgraph=True)(
+            values=values,
+            offsets=offsets,
+            max_sequence_length=max_sequence_length,
+            padding_value=padding_value,
+        )
+        torch.testing.assert_close(ref_output_values, output_values)
+
     @unittest.skipIf(*gpu_unavailable)
     @settings(
         verbosity=Verbosity.verbose,
@@ -486,30 +593,6 @@ def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor:
                 lengths.long(),
             )
 
-        # Converts lengths + values format to COO format
-        # [B], [N] -> [B, N'].
-        # pyre-ignore Missing return annotation [3]
-        def var_list_to_coo(
-            lengths: torch.Tensor,
-            values: torch.Tensor,
-            N: int,
-        ):
-            rows = lengths_to_segment_ids(lengths)
-            num_rows = lengths.size()[0]
-            # This does D&H sync
-            offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            output_size = lengths.sum()
-            # This does D&H sync
-            cols = torch.ops.fbgemm.offsets_range(offsets, output_size)
-            indices = torch.stack([rows, cols])
-            dims = [num_rows, N]
-            # torch.sparse_coo_tensor is not supported by torch.fx, wrap it.
-            return torch.sparse_coo_tensor(
-                indices=indices,
-                values=values,
-                size=dims,
-            )
-
         lengths_ = np.random.randint(low=0, high=max_sequence_length, size=B * T)
         total_lengths = lengths_.sum()
         lengths = torch.from_numpy(lengths_).to(device)
@@ -538,8 +621,6 @@ def var_list_to_coo(
             ref_output_values, torch.cat(output_values_per_table)
         )
 
-    # TODO: reuse code with var_list_to_coo and to_dense
-
     def _to_padded_dense(
         self,
         values: torch.Tensor,
