Make as_strided_copy materialize a new tensor with index. #6624

Merged
alanwaketan merged 7 commits into master from ysiraichi/fix-overlapped-asstride on Mar 5, 2024

Conversation

ysiraichi (Collaborator)

Fix: #5835

This PR implements the arbitrary as_strided function by decomposing it into slicing + indexing. In summary, we slice the base tensor to comply with the given storage_offset, and then index a flattened version of the tensor, gathering the desired elements based on the given size and strides (more explanation in the code).
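
A minimal Python sketch of the decomposition (my own illustration under assumed names, not the PR's C++ code in aten_xla_type.cpp):

import torch

def as_strided_via_slice_and_index(base, size, stride, storage_offset=0):
    # Flatten the base so elements can be gathered by flat index.
    flat = base.reshape(-1)
    if storage_offset > 0:
        # Slice away the first `storage_offset` elements.
        flat = flat[storage_offset:]
    # index[i0, i1, ...] = i0 * stride[0] + i1 * stride[1] + ...
    index = torch.zeros(size, dtype=torch.long)
    for dim, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[dim] = sz
        index = index + torch.arange(sz).view(shape) * st
    # Advanced indexing gathers the desired elements into a new tensor.
    return flat[index]

base = torch.arange(10.)
print(as_strided_via_slice_and_index(base, (4, 3), (2, 1), storage_offset=1))
print(torch.as_strided(base, (4, 3), (2, 1), 1))  # reference output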

cc @miladm @JackCaoG @lezcano


lezcano (Collaborator) commented Feb 27, 2024

Haven't looked at the code in depth, but this sounds plausible. Will review tomorrow.

@bdhirsh we could use this to functionalize as_strided. With this same trick we can even write to the relevant view via index_put.
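
A rough, self-contained sketch of that idea (my own example, not code from this PR): reuse the same flat indices to write through the strided "view" with index_put_.

import torch

def write_through_strided_view(base, values, size, stride, storage_offset=0):
    # Build the same flat indices the read path would use.
    index = torch.full(size, storage_offset, dtype=torch.long)
    for dim, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[dim] = sz
        index = index + torch.arange(sz).view(shape) * st
    # Scatter the values into a copy of the flattened base. With overlapping
    # strides, duplicate indices make the result order-dependent, mirroring the
    # usual aliasing hazards of in-place writes to overlapping views.
    flat = base.reshape(-1).clone()
    flat.index_put_((index.reshape(-1),), values.reshape(-1))
    return flat.view(base.shape)

base = torch.zeros(6)
print(write_through_strided_view(base, torch.ones(2, 2), (2, 2), (2, 1), 1))
# tensor([0., 1., 1., 1., 1., 0.])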

lezcano (Collaborator) left a comment

This is a great algorithm! Thank you @ysiraichi!

I think it is correct modulo potential corner-cases that may pop up.

torch_xla/csrc/aten_xla_type.cpp (resolved)
Comment on lines 733 to 736
if (storage_offset.has_value() && *storage_offset > 0) {
// If there's a storage_offset, slice this tensor, first.
tensor = slice_copy(tensor, 0, *storage_offset, c10::nullopt, 1);
}
lezcano (Collaborator)

You can do this, or simply add storage_offset to index_tensor at the end.

ysiraichi (Collaborator, Author)

You are right. I thought about it briefly and, for some reason, decided it wouldn't be correct. But on second thought, it does make sense.

Comment on lines 730 to 731
// Flatten the tensor, so that it's easier to gather its elements.
tensor = view_copy_symint(tensor, {tensor.numel()});
lezcano (Collaborator)

Rather than this flattening + index, you can simply use torch.take.
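
A small illustration of the suggestion (my own example, not the PR's code): torch.take treats its input as if it were flattened to 1-D, so the explicit view_copy plus advanced indexing can be collapsed into a single call.

import torch

base = torch.arange(12.).reshape(3, 4)
index = torch.tensor([[0, 1], [4, 5]])

via_flatten = base.reshape(-1)[index]  # flatten, then index
via_take = torch.take(base, index)     # same result, no explicit flatten
assert torch.equal(via_flatten, via_take)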

ysiraichi marked this pull request as ready for review on February 28, 2024 at 14:26
ysiraichi requested a review from JackCaoG on February 28, 2024 at 14:26
ysiraichi force-pushed the ysiraichi/fix-overlapped-asstride branch from bc6409c to 7afeb56 on February 28, 2024 at 14:32
lezcano (Collaborator) left a comment

Logic lgtm

ysiraichi force-pushed the ysiraichi/fix-overlapped-asstride branch from 7afeb56 to 9ea4600 on February 29, 2024 at 13:48
ysiraichi (Collaborator, Author) commented Feb 29, 2024

@JackCaoG Could you take a look at this PR whenever you have some time?

ysiraichi (Collaborator, Author)

I believe these export test failures are unrelated to this PR.

@JackCaoG @zpcore @frgossen @vanbasten23 @cota @golechwierowicz
Have you seen these before?

JackCaoG (Collaborator)

Not really. @lsy323, do you know what this unbounded export test is doing?

ysiraichi force-pushed the ysiraichi/fix-overlapped-asstride branch from 4b66320 to 8bb2aea on March 1, 2024 at 15:32
ysiraichi (Collaborator, Author)

@JackCaoG @alanwaketan Could you take a look at this PR when you have some time?

alanwaketan (Collaborator) left a comment

Generally LGTM; I only have one question.

// [[[0]]]
//
std::vector<int64_t> view_shape(dim, 1);
auto index_tensor =
alanwaketan (Collaborator)

I assume this is computed by CPU eager mode in the following code?

ysiraichi (Collaborator, Author)

Yes. Given the size, stride, and offset arguments, we compute the correct indices for materializing the tensor ahead of time, so there is no need to compute them at runtime.
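
A hedged Python illustration of that precomputation (assumed names; the actual code is C++): the index tensor depends only on size, stride, and storage_offset, so it can be built once with CPU eager ops, independent of the tensor's data.

import torch

size, stride, storage_offset = (2, 3), (3, 1), 2
index_tensor = torch.full(size, storage_offset, dtype=torch.long)
for dim, (sz, st) in enumerate(zip(size, stride)):
    # Reshape each dimension's arange so it broadcasts along its own axis only,
    # loosely analogous to the view_shape vector in the snippet quoted above.
    view_shape = [1] * len(size)
    view_shape[dim] = sz
    index_tensor = index_tensor + torch.arange(sz).view(view_shape) * st
print(index_tensor)
# tensor([[2, 3, 4],
#         [5, 6, 7]])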

alanwaketan (Collaborator)

Thanks!

alanwaketan merged commit 3abc21d into master on Mar 5, 2024
18 checks passed
lsy323 (Collaborator) commented Mar 7, 2024

Hi @ysiraichi, I found that this PR causes a performance regression on TPU v4-8 (it can also be reproduced on v3-8). The regression can be reproduced by running the following command:

python test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256

At commit b8864fc5a5ba91640904b075d69aee0c5f9ceff4, the training log shows:

Epoch 1 train begin 21:58:52
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:00:16
| Training Device=xla:0/3 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:00:16
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:00:16
| Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:00:16
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:02:37
| Training Device=xla:0/1 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:02:37
| Training Device=xla:0/3 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:02:37
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:02:37
| Training Device=xla:0/1 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:03:09
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:03:09
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:03:09
| Training Device=xla:0/3 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:03:09

At commit 3abc21df7aaa176804d3cbbc60f5078d579831b7 (this PR's merge commit), it's much slower:

Epoch 1 train begin 22:11:06
| Training Device=xla:0/3 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:17:32
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:17:32
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:17:32
| Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.89059 Rate=0.00 GlobalRate=0.00 Time=22:17:32
| Training Device=xla:0/3 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:31:14
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:31:14
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:31:14
| Training Device=xla:0/1 Epoch=1 Step=200 Loss=0.04890 Rate=0.00 GlobalRate=0.00 Time=22:31:14
| Training Device=xla:0/1 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:37:41
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:37:41
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:37:41
| Training Device=xla:0/3 Epoch=1 Step=400 Loss=0.01260 Rate=0.00 GlobalRate=0.00 Time=22:37:41

I'm reverting this PR for now, since we are close to the 2.3 branch cut date (March 11th).

Could you please re-land the PR after the perf regression is resolved? Thanks a lot

Successfully merging this pull request may close the following issue: No support for overlapped tensors (#5835).