Does data_prefetcher() really speed up training? #304
It really depends on your model. Without looking at a visual profile, it's hard to tell how much non-overlapped data loading costs, and whether prefetching successfully overlaps it. For data prefetching to overlap, the source batch on the CPU must be pinned (in other words, the dataloader should receive the argument …). The prefetcher is not an "official" piece of Apex; it was just a cool piece of code that we found useful for ImageNet. I think it gives us a minor (maybe 5%) speedup for ResNet-50, but that depends on batch size and on whether we're running single-GPU or multi-GPU.
Yes, …
@mcarilli @ngimel Thanks for your replies. I did some more tests, and data_prefetcher() still doesn't speed things up. I found something I want to share with you.
But I think that in the PyTorch CUDA implementation each device can execute only one operation at a time, and all operations queue up in line, even though there is an asynchronous option, … My batch is a tuple with 3 elements. I send it to the GPU with … I did time logging for each computation stage and found that both stages took much more time than before. So I think that when data is sent to the GPU asynchronously, it isn't actually sent at that point; only when the computation needs the data is it sent to GPU memory. That's why the time of both of my computation stages increased. I think that's why …

All of the above is a summary of my tests and experience, and maybe it's not right. But I found the following in the official PyTorch documentation: PyTorch CUDA asynchronous-execution
Maybe I have misunderstood something. BTW, I have another problem. When I run my training on two nodes with 16 GPUs on each node, instead of on one node with 16 GPUs, the time spent sending data to the GPU nearly doubles! I am very confused. Thanks again.
By default, PyTorch enqueues all operations involving the GPU (kernel launches, CPU->GPU memcopies, and GPU->CPU memcopies) on the same stream (the "default stream"). Operations on the same stream are serialized and can never overlap. For two operations to overlap, they must be in different streams. Also, for CPU->GPU and GPU->CPU memcopies in particular, the CPU-side memory must be pinned; otherwise the memcopy will be blocking with respect to all streams. The forward pass is performed in the default stream. Therefore, for a CPU->GPU prefetch (of the next iteration's data) to overlap with the forward pass of the current iteration …
Our data_prefetcher satisfies both of these requirements. For overlapped prefetching, supplying pin_memory=True to the dataloader is always required (to satisfy 1.). If your data batch is a tuple of Tensors, then supplying …

I'm not sure why the data-loading time doubles for a 2-node run. Are your dataset's files only on one node's hard drive, and accessed from the other node via a shared network drive or something similar? That would make one node slower than the other. For best results, the full dataset's files should be present on the hard drive of both nodes.
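To make the two requirements concrete (pinned source memory, and a copy issued on a non-default stream), here is a minimal sketch of a side-stream prefetcher. It is not Apex's data_prefetcher: the class name `SideStreamPrefetcher` is hypothetical, it assumes each batch is a flat list or tuple of tensors (as a DataLoader over a TensorDataset yields), and it simply iterates normally when no GPU is present so it can run anywhere.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class SideStreamPrefetcher:
    """Hypothetical sketch: copy batch i+1 to the GPU on a side stream while
    batch i is used on the current stream. Assumes each batch is a flat
    list/tuple of tensors. Without CUDA it just yields the batches."""

    def __init__(self, loader):
        self.loader = loader
        self.use_cuda = torch.cuda.is_available()
        self.stream = torch.cuda.Stream() if self.use_cuda else None

    def _start_copy(self, batch):
        # Issue the copies on a non-default stream. non_blocking=True only
        # runs asynchronously if the source is pinned (pin_memory=True).
        with torch.cuda.stream(self.stream):
            return [t.cuda(non_blocking=True) for t in batch]

    def __iter__(self):
        it = iter(self.loader)
        if not self.use_cuda:
            for batch in it:
                yield list(batch)
            return
        nxt = next(it, None)
        ready = self._start_copy(nxt) if nxt is not None else None
        while ready is not None:
            current = torch.cuda.current_stream()
            # The current stream must not read the batch before its copy ends.
            current.wait_stream(self.stream)
            batch = ready
            # These tensors were allocated on the side stream; tell the caching
            # allocator they are also used on `current`, so their memory is not
            # handed out again until that work has finished.
            for t in batch:
                t.record_stream(current)
            # Start copying the following batch before handing this one out.
            nxt = next(it, None)
            ready = self._start_copy(nxt) if nxt is not None else None
            yield batch


dataset = TensorDataset(torch.arange(10.0).unsqueeze(1), torch.arange(10))
loader = DataLoader(dataset, batch_size=4, pin_memory=torch.cuda.is_available())
batches = list(SideStreamPrefetcher(loader))
```

The copy of the next batch is enqueued before the current batch is yielded, so on a GPU it can overlap with whatever the caller runs on the default stream.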
@mcarilli I have one question about the training process. In one epoch, I measured the time consumed by four parts. Part 1: `for _, (img, tgt) in enumerate(dataloader)`, which costs 4.18 sec …
I'm not sure how you're obtaining your timings, but CPU timings can be deceptive because the GPU operates asynchronously. It's hard to tell what's really causing your bottleneck without looking at a visual profile. You can try profiling using the example I posted here.
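A CPU timer around a training step may only measure how long it takes to enqueue GPU work, unless something forces a synchronization. A minimal sketch of timing GPU work with CUDA events instead; `time_gpu_work` is a hypothetical helper name, and it falls back to a CPU timer when no GPU is present. A visual profiler still gives the fuller picture (e.g. whether copies actually overlap with kernels).

```python
import time

import torch


def time_gpu_work(fn, *args, repeats=10):
    """Average time of fn(*args) in milliseconds (hypothetical helper).

    GPU kernels are launched asynchronously, so a CPU timer around fn() may
    only measure launch overhead. CUDA events are recorded into the current
    stream, so elapsed_time() measures when the GPU actually reached them.
    """
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()  # enqueued on the current stream
        for _ in range(repeats):
            fn(*args)
        end.record()
        end.synchronize()  # block the CPU until the GPU reaches `end`
        return start.elapsed_time(end) / repeats
    # CPU fallback: ops run synchronously, so a wall-clock timer is fine.
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - t0) * 1000.0 / repeats


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(256, 256, device=device)
time_gpu_work(torch.matmul, x, x, repeats=1)  # warm-up: one-time init costs
avg_ms = time_gpu_work(torch.matmul, x, x)
print(f"matmul: {avg_ms:.3f} ms per call on {device}")
```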
@mcarilli Thank you very much! Under your guidance I used the NVIDIA Visual Profiler to analyse my code. I think my dataloader problem may have been related to initialization; now the time spent preparing data is very small. Next I want to optimize the time it takes to transfer data from CPU to GPU. However, when I used the prefetcher, I ran into a problem: the dataloader returned NoneType instead of images and targets. I also noticed that the images and targets were correct before they entered record_stream() (memory -> CUDA memory). In your example, I found your comments saying that if record_stream does not work, we can instead allocate new space in CUDA memory. But I could not make that work. The following is my prefetcher code. The difference from your example is that in my case, images are organized as a list ([tensor, tensor, tensor, tensor]) and the ground truths (GTs) are in a dictionary. Sorry to bother you again.
Using the existing uncommented code (aka record_stream) is the preferred approach. The above code looks like it should work, and it properly handles the end of the epoch, where next_features and next_targets are None. What exactly is the error you see? Also, I said this before, but it was a while ago so it's worth repeating: the prefetcher will not enable overlap unless you also supply …
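The commented-out alternative discussed here (allocating the destination in CUDA memory instead of relying on record_stream()) can be sketched as follows. This is an illustration under assumptions, not Apex's code: the destination is allocated while the main stream is current, the side stream waits for the main stream before writing into it, and the consumer waits for the side stream before reading. `copy_on_side_stream` is a hypothetical name; without a GPU it just clones the tensor.

```python
import torch


def copy_on_side_stream(cpu_tensor, side_stream):
    """Hypothetical helper: fill a GPU buffer on `side_stream` without
    record_stream(). Returns a CPU clone when side_stream is None."""
    if side_stream is None:
        return cpu_tensor.clone()
    main = torch.cuda.current_stream()
    # Allocate while the main stream is current, so the caching allocator
    # associates this memory block with the main stream.
    dst = torch.empty_like(cpu_tensor, device="cuda")
    # The block may have just been freed by main-stream work; make the side
    # stream wait for that work before it starts writing into it.
    side_stream.wait_stream(main)
    with torch.cuda.stream(side_stream):
        dst.copy_(cpu_tensor, non_blocking=True)  # async if cpu_tensor is pinned
    return dst


side = torch.cuda.Stream() if torch.cuda.is_available() else None
src = torch.arange(8.0)
if side is not None:
    src = src.pin_memory()
buf = copy_on_side_stream(src, side)
if side is not None:
    # The consumer must still wait for the copy before reading the buffer.
    torch.cuda.current_stream().wait_stream(side)
result = (buf * 2).cpu()
```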
@mcarilli Sorry for my late reply. I have used pin_memory=True in the PyTorch dataloader. The code did not report an obvious error. However, after record_stream(), the features became NoneType, which caused failures.
After running the above code, I get the following output:
I tried splitting the features list and the targets dictionary into single tensors, and then the code runs normally. However, getting batch data seems very slow compared with the traditional way.
```python
def preload(self):
    try:
        self.next_input, self.next_target = next(self.loader)
    except StopIteration:
        self.next_input = None
        self.next_target = None
        return
    with torch.cuda.stream(self.stream):
        self.next_input = self.next_input.cuda(non_blocking=True)
        self.next_target = self.next_target.cuda(non_blocking=True)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.next_input = self.next_input.half()
        # else:
        self.next_input = self.next_input.float()
        self.next_input = self.next_input.sub_(self.mean).div_(self.std)
```

Why don't you put …
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
# _DataLoaderIter is a private class that only exists in older PyTorch releases.
from torch.utils.data.dataloader import _DataLoaderIter


class CUDAPrefetcher(_DataLoaderIter):
    def __init__(self, loader, device=None, priority=0):
        if not torch.cuda.is_available():
            raise Exception("Only CUDA")
        super(CUDAPrefetcher, self).__init__(loader)
        self.device = device
        # Side stream used for the asynchronous host-to-device copies.
        self.stream = torch.cuda.Stream(device=device, priority=priority)
        self.last = None  # batch whose copy was started on the side stream

    def __next__(self):
        # Make the default stream wait until the side-stream copy has finished.
        torch.cuda.default_stream(device=self.device).wait_stream(stream=self.stream)
        result = self.last
        if result is None:
            # First batch: nothing was prefetched yet, so copy it synchronously.
            result = super(CUDAPrefetcher, self).__next__()  # may raise StopIteration
            for x, d in enumerate(result):
                result[x] = d.to(device=self.device, non_blocking=False)
        try:
            # Fetch the next batch and start copying it on the side stream.
            # Note: no record_stream() is called for these tensors.
            self.last = super(CUDAPrefetcher, self).__next__()
            with torch.cuda.stream(stream=self.stream):
                for x, d in enumerate(self.last):
                    self.last[x] = d.to(device=self.device, non_blocking=True)
        except StopIteration:
            self.last = None
        return result


for ip, op in CUDAPrefetcher(
    DataLoader(TensorDataset(inputs, outputs), batch_size=20000, pin_memory=True),
    device=dev,
    priority=-10,
):
    ...  # training step
```

This is my attempt to solve the problem, but there is no acceleration.
@Lausannen, have you fixed this bug properly?
@youngfly11 Not yet. Maybe https://zhuanlan.zhihu.com/p/80695364 can help you.
I have the same issue. Why does record_stream output NoneType?
The solution I found was to manually keep every variable as an attribute of the data_prefetcher: `self.attribute_name = attribute`
Hi, how did you solve this problem? Can you show a code snippet here for clarity? Thanks
Hello, I found the bug: it is possible to normalize self.next_input before the other stream has finished sending the data to GPU memory!
After doing some deep profiling, I think the problem is due to the …
Note the phrase "until all current work queued on stream are complete." It means the tensor memory will not be reused while the current stream is still working. But for some reason, it blocks the default stream until the tensor's stream is completed. Here is the timeline:
@DelightRun (2) From the data_prefetcher code, I think it prefetches two batches before the first iteration (something like the fifth iteration's data overlapping with the batch processing of the third iteration). I am not sure why it prefetches two here. I'd appreciate any response.
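On the question of how many batches are in flight: the sketch below is not Apex's code, but a generic pure-Python model in which the prefetch depth is an explicit parameter. `prefetch`, `start_transfer` and `depth` are hypothetical names; in a CUDA prefetcher `start_transfer` would be the side-stream copy, and the consumer would still need `wait_stream()` before touching each batch.

```python
from collections import deque
from typing import Callable, Deque, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def prefetch(source: Iterable[T], start_transfer: Callable[[T], U],
             depth: int = 1) -> Iterator[U]:
    """Yield start_transfer(item) for each item of `source`, keeping up to
    `depth` later items already started while the caller uses the current one.
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")
    it = iter(source)
    in_flight: Deque[U] = deque()
    done = object()  # sentinel for an exhausted source

    def launch_next() -> None:
        # Start the transfer of one more item, if the source has any left.
        item = next(it, done)
        if item is not done:
            in_flight.append(start_transfer(item))

    for _ in range(depth):
        launch_next()  # fill the pipeline
    while in_flight:
        current = in_flight.popleft()
        launch_next()  # refill before handing `current` to the caller
        yield current


log = []


def fake_transfer(x):
    log.append(f"copy {x}")
    return x


for batch in prefetch(range(3), fake_transfer, depth=1):
    log.append(f"use {batch}")
```

With `depth=1`, the log reads `copy 0, copy 1, use 0, copy 2, use 1, use 2`: the copy of batch i+1 is started before batch i is used, so one batch is prefetched ahead. A larger depth keeps more copies outstanding at the cost of extra GPU memory.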
For those getting …

```python
# xaf.record_stream(torch.cuda.current_stream()) always returns None, so the list will be full of None's
features = [xaf.record_stream(torch.cuda.current_stream()) for xaf in features]
```

you should instead do something like this:

```python
features_list = []
for k, v in features.items():
    v.record_stream(torch.cuda.current_stream())
    features_list.append(v)
```
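For a batch shaped like the one in this thread (a list of image tensors plus a dict of targets), the same idea can be applied recursively: move the tensors on the side stream, then call `record_stream()` purely for its side effect and keep the structure itself. A minimal sketch; `to_device_nested`, `record_nested` and the dict keys are hypothetical, and only plain list/tuple/dict containers are handled.

```python
import torch


def to_device_nested(obj, device):
    """Recursively copy tensors in (possibly nested) lists/tuples/dicts to
    `device`. Other objects are returned unchanged. Plain list/tuple/dict
    only; namedtuples and custom classes are not handled in this sketch."""
    if torch.is_tensor(obj):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, dict):
        return {k: to_device_nested(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device_nested(v, device) for v in obj)
    return obj


def record_nested(obj, stream):
    """Call record_stream() on every CUDA tensor in a nested structure.

    record_stream() returns None, so it is used only for its side effect and
    the (unchanged) structure itself is returned."""
    if torch.is_tensor(obj):
        if obj.is_cuda:
            obj.record_stream(stream)
    elif isinstance(obj, dict):
        for v in obj.values():
            record_nested(v, stream)
    elif isinstance(obj, (list, tuple)):
        for v in obj:
            record_nested(v, stream)
    return obj


features = [torch.ones(2, 3) for _ in range(4)]  # list of image tensors
targets = {"boxes": torch.zeros(5, 4), "labels": torch.arange(5)}  # made-up keys

if torch.cuda.is_available():
    device = torch.device("cuda")
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        # For real overlap the CPU tensors should be pinned first.
        features, targets = to_device_nested([features, targets], device)
    current = torch.cuda.current_stream()
    current.wait_stream(side)
    record_nested([features, targets], current)  # return values not needed
else:
    device = torch.device("cpu")
    features, targets = to_device_nested([features, targets], device)
```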
I used your Python code https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py#L256

My code is:

`while True` …

I think this data_prefetcher could speed up training, because another stream sends data to GPU memory while the model is running on the GPU, so there is only a very small gap between two iterations. However, this trick does not work for me. So please help me: is data_prefetcher really meant to speed up training?