
optimizer.step() takes too long in XLA/TPU #7923

Open
bytetriper opened this issue Aug 29, 2024 · 12 comments

@bytetriper

🐛 Bug

Common optimizers like Adam/AdamW take too long in optimizer.step() for small models. I tested a small ViT with 5.8M parameters, and torch.optim.AdamW takes ~0.2s for a single step.

To Reproduce

Below is a minimal example that can be run with python test.py. It trains a 5.8M-parameter ViT on fake image data with a local batch size of 8. The average training speed measured on a TPU v4-8 is 4.2 it/s, which is much lower than expected.

from functools import partial
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch_xla
import torch_xla.amp
import torch_xla.amp.syncfree
import torch_xla.core
from torchvision.models import VisionTransformer
import torch_xla.core.xla_model as xm
import itertools
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.core.xla_model import collective_broadcast
from tqdm import tqdm
from torch_xla.distributed import xla_backend
from torch_xla.distributed import xla_multiprocessing as xmp
import torch_xla.runtime as xr
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
import os
from contextlib import nullcontext
cache_compile_path = './xla_compile/test_mp_opt' 
os.makedirs(cache_compile_path, exist_ok=True)
xr.initialize_cache(cache_compile_path, readonly=False)
import torch_xla.debug.profiler as xp
import torch_xla.debug.metrics as met
from torch_xla.debug.profiler import trace
import torch.autograd.profiler as profiler
profile_log_path = './profile' 
tracing = True
os.makedirs(profile_log_path, exist_ok=True) 
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = int(length)
        self.size = size
    def __getitem__(self, index):
        return torch.randn(*self.size)
    def __len__(self):
        return self.len
    def _collate_fn(self, batch):
        return torch.stack(batch)
### below are some vision models with seq length 64 ###
def model_S(): # 5.8M
    return VisionTransformer(
        image_size = 256,
        patch_size = 32,
        num_layers= 6,
        num_heads = 16,
        hidden_dim = 256,
        mlp_dim = 1024
    )
def model_M(): # 
    return VisionTransformer(
        image_size = 256,
        patch_size = 32,
        num_layers= 12,
        num_heads = 16,
        hidden_dim = 256,
        mlp_dim = 1024
    )
def model_L(): # 88.24M
    return VisionTransformer(
        image_size = 256,
        patch_size = 32,
        num_layers= 12,
        num_heads = 16,
        hidden_dim = 768,
        mlp_dim = 3072
    )
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def broadcast_master_param(model: torch.nn.Module) -> None:
  """
  Broadcast the model parameters from master process to other processes
  """
  parameters_and_buffers = list(
      itertools.chain(model.parameters(), model.buffers()))
  collective_broadcast(parameters_and_buffers, pin_layout=True)
  xm.mark_step()
def main(rank:int = None):
    print(f'xm.ordinal: {xm.get_ordinal()}')
    device = xm.xla_device() # current device
    print(f'[!]device: {device}')
    global profile_log_path , tracing
    world_size = xm.xrt_world_size() # world size
    rank = xm.get_ordinal() # rank
    cpu_model = model_S()
    model = cpu_model.to(device)
    # broadcast the model to all devices
    xm.rendezvous('init')
    xm.master_print(f'[!]broadcasting model parameters...')
    broadcast_master_param(model)
    param_count = count_parameters(model)
    xm.master_print(f'[!]model broadcasted, total trainable parameters: {param_count/1e6:.2f}M')
    xm.mark_step()
    # create optimizer
    optimizer_cls_dict = {
        'naive': optim.AdamW,
        'syncfree': torch_xla.amp.syncfree.AdamW,
        'zero': partial(ZeroRedundancyOptimizer, optimizer_class=optim.AdamW),
    }
    choice = 'naive' # change to see other optimizers' performance
    optimizer = optimizer_cls_dict[choice](model.parameters(), lr=1e-4)
    #optimizer = ZeroRedundancyOptimizer(
    #    model.parameters(),
    #    optimizer_class=optim.AdamW,
    #    lr=1e-4,
    #)
    xm.master_print(f'[!]optimizer created')
    xm.mark_step()
    # create data loader
    dataset = RandomDataset((3, 256, 256), 1e4) # 1e4 is large enough to see the performance
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True
    )
    loader_train = DataLoader(
        dataset,
        batch_size=8, # small local batch size 8
        sampler=sampler_train,
        num_workers=4,
        drop_last=True,
        collate_fn=dataset._collate_fn
    )
    xm.mark_step()
    # start training
    server = xp.start_server(9012, only_on_master=True) if tracing else None
    model.train()
    loss_fct = nn.CrossEntropyLoss()
    for epoch in range(1):
        met.clear_all()
        loader = ParallelLoader(loader_train, [device]).per_device_loader(device)       
        tbar = tqdm(loader, total=len(loader), desc=f'[!]epoch {epoch}', disable=not xm.is_master_ordinal())
        for i, data in enumerate(tbar):
            if i == 0 and xm.is_master_ordinal() and tracing:
                xp.trace_detached('localhost:9012', profile_log_path, duration_ms = 100000) # trace 100s, change the duration if needed
            with xp.StepTrace('train_cls',step_num=i) if tracing else nullcontext():
                with xp.Trace('build_graph') if tracing else nullcontext():
                    data = data.to(device)
                    labels = torch.zeros(data.size(0), dtype=torch.long).to(data.device) # pseudo labels
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fct(output, labels)
                with xp.Trace('backward') if tracing else nullcontext():
                    loss.backward()
                with xp.Trace('reduce_grad') if tracing else nullcontext():
                    if choice != 'zero': # zero optimizer has its own reduce_gradients
                        xm.reduce_gradients(optimizer)
                with xp.Trace('step') if tracing else nullcontext():
                    optimizer.step()
                if not tracing: # if tracing StepTrace will do the mark step
                    xm.mark_step()
            tbar.set_postfix({'loss': loss.item()}) # fetching is not good but it won't hurt much
    xm.mark_step()
    xm.master_print(met.metrics_report())
    xm.master_print(f'[!]training finished')  
if __name__ == '__main__':
    xmp.spawn(main, args=(), start_method='fork')
    #main()

A one-step trace screenshot shows that most of the time is spent in optimizer.step():

[screenshot: profiler trace of a single training step, dominated by optimizer.step()]

Expected behavior

A much faster optimizer.step().

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.3.0

Additional context

As suggested by @JackCaoG, setting XLA_DISABLE_FUNCTIONALIZATION=1 mitigates the problem with a ~5x speedup. I would love to see more detailed info about this.
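
For reference, a minimal sketch of how the flag can be applied to the repro above; this assumes the variable is picked up when torch_xla initializes, so it is set (or exported in the shell) before torch_xla is imported:

    # Set the flag before importing torch_xla so it is visible when the runtime
    # initializes; exporting XLA_DISABLE_FUNCTIONALIZATION=1 in the shell before
    # running `python test.py` works as well.
    import os
    os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

    import torch_xla  # imported only after the flag is set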

Also, only _single_tensor_adamw is supported on XLA/TPU right now. Are there any plans to support optimizations like fused/foreach/CPUAdamW?

@JackCaoG
Collaborator

@bdhirsh FYI, this is what we discussed last week.

There are 2 issues, I think:

  1. lerp, addcdiv, and addcmul all get decomposed although pytorch/xla already has direct lowerings (one way to check which ops get a direct lowering is sketched right after this list)
  2. each torch op calls the Python meta function, which is slow
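
A minimal sketch of one way to check point 1, using torch_xla's debug metrics; the counter naming here is an assumption and varies across torch_xla versions:

    import torch_xla.debug.metrics as met

    # Run one training step first, then inspect the op counters. Ops that hit a
    # direct XLA lowering are usually counted as xla::<op>, ops that fall back to
    # CPU appear as aten::<op>, and decomposed ops simply do not appear under
    # their own name.
    for name in met.counter_names():
        if any(key in name for key in ("lerp", "addcdiv", "addcmul")):
            print(name, met.counter_value(name))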

Pasting some of the chat between Brian and me to add more context:

one maybe lower-blast-radius fix would be to agree that removing the meta compute for views has higher risk, but removing it for in-place ops is unlikely to matter in practice, and the in-place ops are where the slowdown is coming from for you anyway

aka just tweak that line to never run for inplace ops in the codegen.

you can basically tweak this line of the codegen: https://github.com/pytorch/pytorch/blob/main/torchgen/gen_functionalization_type.py#L718

so that the if statement gets generated to always return false for xla tensors

so you can do something like this:
    check_any_mutated_tensors_are_xla = " || ".join(
        ["false"]
        + [
            f"{a}.device().type() == c10::DeviceType::XLA"
            for a in mutated_tensor_names
        ]
    )

and then update that if statement in the codegen to:
      if (!({str(check_any_mutated_tensors_are_xla)}) && {str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{

so that the overall conditional always returns false when there are any XLA inputs to a mutable-op functionalization kernel

@JackCaoG
Collaborator

Another thing you could do for now is use a larger batch size. Your issue here is that you are tracing-bound; increasing the batch size will increase the device execution time while keeping the tracing time constant, so you will be able to utilize the device better.
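
For the repro above, that just means increasing batch_size in the DataLoader; a quick sketch (128 is only an illustrative value, not a recommendation):

    loader_train = DataLoader(
        dataset,
        batch_size=128,  # larger per-core batch amortizes the fixed per-step tracing cost
        sampler=sampler_train,
        num_workers=4,
        drop_last=True,
        collate_fn=dataset._collate_fn
    )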

@bytetriper
Author

bytetriper commented Aug 29, 2024

Thanks for the response! Yes, increasing the batch size would help a lot, and I usually deal with large batch sizes. But sometimes for large pods I have to use a small batch per core, and that's when I found the problem.

@ManfeiBai
Collaborator

Hi @JackCaoG, is it OK to assign this ticket to you?

@ysiraichi
Collaborator

@bytetriper Could you check if the issue still happens using this branch?

@miladm
Collaborator

miladm commented Sep 16, 2024

@bytetriper Have you run this code on XLA:GPU? In case you did, do you see a similar performance issue there?

@ysiraichi
Collaborator

I have tried running resnet18 with the AdamW optimizer, passing the foreach=True parameter, using XLA:CUDA. However, I haven't been able to reproduce it.

That said, I did notice that, after my branch, the execution of the _foreach_{lerp,addcdiv,addcmul}_ meta functions, which @JackCaoG mentioned as one of the issues, was faster than without it:

  • lerp: 1,550ms → 140ms
  • addcdiv: 640ms → 350ms
  • addcmul: 620ms → 300ms

Note: in this case, _foreach_<op> meta functions will call <op> many times. Therefore, reducing <op> latency improves _foreach_<op> latency.
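
A minimal sketch of that kind of setup (resnet18 plus AdamW with foreach=True on an XLA device); the batch size, input shape, and step count here are illustrative assumptions, not the values actually used:

    import torch
    import torch.nn as nn
    import torchvision
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # XLA:CUDA or XLA:TPU depending on the machine
    model = torchvision.models.resnet18().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, foreach=True)
    loss_fct = nn.CrossEntropyLoss()

    x = torch.randn(8, 3, 224, 224, device=device)   # illustrative batch
    y = torch.randint(0, 1000, (8,), device=device)  # pseudo labels

    for step in range(10):  # illustrative step count
        optimizer.zero_grad()
        loss = loss_fct(model(x), y)
        loss.backward()
        optimizer.step()
        xm.mark_step()  # cut the graph so each step is traced and executed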

@JackCaoG Could you help check whether my branch fixes this issue?

@JackCaoG
Collaborator

I will try to repro on my end, but I am busy these two days.

@bytetriper
Author

> @bytetriper Have you run this code on XLA:GPU? In case you did, do you see a similar performance issue there?

Hi. No, I haven't. But I think this is a TPU issue: XLA:TPU does not support many of the fancier AdamW paths like foreach/fused, only a naive optimizer, which in my opinion contributes a lot to this issue.

@ysiraichi
Collaborator

@bytetriper Still, in the branch I mentioned above, I managed to reduce the latency of a few AdamW operations. Could you check if that solves the issue (even on TPU)?

@miladm
Collaborator

miladm commented Sep 23, 2024

@JackCaoG do we plan to land this PR in 2.5?

@JackCaoG
Collaborator

I think the fix should stay in nightly.

ysiraichi added a commit to pytorch/pytorch that referenced this issue Sep 28, 2024
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their
respective inplace versions).

These functions only had refs implementations, which was the root cause of significant
overhead ([issue][1]) when running the `AdamW` optimizer step on the PyTorch/XLA
backend. Running the meta functions resulted in the following improvements:

- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)

[1]: pytorch/xla#7923

ghstack-source-id: f08891d8ecfd949a298ab6603534297caaf9deaf
Pull Request resolved: #136909
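
As a quick illustration of what a meta function provides (shape and dtype inference without real compute), these ops can be exercised on meta tensors; this is only a sketch of the mechanism, independent of the PR itself:

    import torch

    # Meta tensors carry only shape/dtype metadata, so calling the ops on them
    # runs the shape-inference path without touching real data.
    a = torch.empty(1024, device="meta")
    b = torch.empty(1024, device="meta")
    c = torch.empty(1024, device="meta")

    print(torch.lerp(a, b, 0.5).shape)    # torch.Size([1024])
    print(torch.addcmul(a, b, c).shape)   # torch.Size([1024])
    print(torch.addcdiv(a, b, c).shape)   # torch.Size([1024])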