Skip to content

Commit

Permalink
Fix high memory usage in autotest (#7988)
Browse files Browse the repository at this point in the history
* Fix high memory usage in autotest

* restore conv3d test

Signed-off-by: daquexian <daquexian566@gmail.com>

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
daquexian and mergify[bot] authored Apr 14, 2022
1 parent a5aa164 commit ae13b04
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
15 changes: 7 additions & 8 deletions python/oneflow/test/modules/test_conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@

@flow.unittest.skip_unless_1n1d()
class TestConv3DModule(flow.unittest.TestCase):
# Disable this test for unknown error
# @autotest(n=3)
# def test_nn_functional_conv3d(test_case):
# device = random_device()
# img = torch.ones((1, 3, 224, 224, 224), requires_grad=True).to(device)
# kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device)
# y = torch.nn.functional.conv3d(img, kernel)
# return y
@autotest(n=3)
def test_nn_functional_conv3d(test_case):
device = random_device()
img = torch.ones((1, 3, 16, 16, 16), requires_grad=True).to(device)
kernel = torch.ones((6, 3, 3, 3, 3), requires_grad=True).to(device)
y = torch.nn.functional.conv3d(img, kernel)
return y

@autotest(n=10)
def test_conv3d_with_random_data(test_case):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
import os
import warnings
import gc

import numpy as np
import oneflow as flow
Expand Down Expand Up @@ -843,6 +844,10 @@ def __eq__(self, other):
else:
return self.pytorch == other

def __del__(self):
# force running gc to avoid the periodic gc related to metaclass
gc.collect()


dual_modules_to_test = []
dual_objects_to_test = []
Expand Down Expand Up @@ -983,6 +988,8 @@ def new_f(test_case, *args, **kwargs):
loop_limit = successful_runs_needed * 20
current_run = 0
while successful_runs_needed > 0:
# force running gc to avoid the periodic gc related to metaclass
gc.collect()
clear_note_fake_program()
if current_run > loop_limit:
raise ValueError(
Expand Down

0 comments on commit ae13b04

Please sign in to comment.