Commit

Increase feature.
Format.

Delete cleanup since it's already done in conftest.py#scope_function

Change features to pass tests.
KazukiYoshiyama-sony committed Oct 23, 2019
1 parent 68b156b commit 1b0721c
Showing 3 changed files with 12 additions and 15 deletions.
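
For context, the first file below, python/test/test_grad.py, exercises nnabla's nn.grad API, which builds the gradients of graph outputs as new Variables in the computation graph. A minimal usage sketch, assuming the public nnabla Python API of this period (illustrative only, not part of this commit):

    import numpy as np
    import nnabla as nn
    import nnabla.functions as F

    # Tiny graph: y = sum(x * x); ask for dy/dx as a graph Variable.
    x = nn.Variable.from_numpy_array(np.random.randn(2, 3).astype(np.float32))
    x.need_grad = True
    y = F.sum(x * x)

    gx = nn.grad([y], [x])[0]  # symbolic gradient, itself a Variable
    gx.forward()               # runs the required forward/backward functions
    assert np.allclose(gx.d, 2 * x.d)  # analytic gradient of sum(x^2)
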
17 changes: 6 additions & 11 deletions python/test/test_grad.py
@@ -40,11 +40,13 @@ def conv(x, maps=8, name="conv"):
h = conv(h, maps=4, name="conv1")
h = F.max_pooling(h, (2, 2))
h = conv(h, maps=4, name="conv2")
- h = conv(h, maps=8, name="conv3") if not shared else conv(h, maps=4, name="conv2")
+ h = conv(h, maps=8, name="conv3") if not shared else conv(
+     h, maps=4, name="conv2")
h = F.average_pooling(h, h.shape[2:])
h = PF.affine(h, 10)
return h

+
@pytest.mark.parametrize("seed", [311])
@pytest.mark.parametrize("ctx", ctx_list)
@pytest.mark.parametrize("auto_forward", [True, False])
@@ -89,8 +91,6 @@ def test_resnet_expansion(seed, ctx, auto_forward, flag_grad_outputs, shared):
assert_allclose(
inp.g, grad.d, atol=1e-6)

- # Clean up
- nn.set_auto_forward(False)

@pytest.mark.parametrize("seed", [311])
@pytest.mark.parametrize("ctx", ctx_list)
@@ -137,8 +137,6 @@ def test_multiple_objectives(seed, ctx, auto_forward):
assert_allclose(
inp.g, grad.d, atol=1e-6)

- # Clean up
- nn.set_auto_forward(False)

@pytest.mark.parametrize("seed", [311])
@pytest.mark.parametrize("ctx", ctx_list)
@@ -187,9 +185,6 @@ def test_grad_outputs(seed, ctx, auto_forward, type_grad_outputs):
assert_allclose(
inp.g, grad.d, atol=1e-6)

- # Clean up
- nn.set_auto_forward(False)


@pytest.mark.parametrize("seed", [311])
@pytest.mark.parametrize("ctx", ctx_list)
@@ -202,20 +197,23 @@ def add(x, derivate=0):
return 3 * np.ones_like(x)
if derivate == 2:
return np.zeros_like(x)
+
def sub(x, derivate=0):
if derivate == 0:
return x - x - x
if derivate == 1:
return -1 * np.ones_like(x)
if derivate == 2:
return np.zeros_like(x)
+
def mul(x, derivate=0):
if derivate == 0:
return x * x * x
if derivate == 1:
return 3 * x ** 2
if derivate == 2:
return 6 * x
+
def div(x, derivate=0):
if derivate == 0:
return x / x / x
@@ -241,6 +239,3 @@ def div(x, derivate=0):
# Second-order gradient
dy_dx[0].backward()
assert_allclose(x.g, math_type(xd, 2))
-
- # Clean up
- nn.set_auto_forward(False)
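
The # Clean up / nn.set_auto_forward(False) lines deleted above are, per the commit message, redundant because the reset is already done in conftest.py#scope_function. That conftest.py is not part of this diff; a minimal sketch of what such an autouse fixture might look like (hypothetical names and behavior):

    import pytest
    import nnabla as nn

    @pytest.fixture(scope="function", autouse=True)
    def scope_function():
        # Run the test body, then restore the global auto-forward flag.
        yield
        nn.set_auto_forward(False)
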
7 changes: 4 additions & 3 deletions python/test/test_graph.py
@@ -79,18 +79,18 @@ def test_graph_model(model, seed):
z3 = PF.affine(z2, 5)
elif model == "recurrent":
with nn.parameter_scope('fc1'):
- z = PF.affine(x, 4)
+ z = PF.affine(x, 8)
z2 = F.relu(z, inplace=True)
h = z2
for _ in range(2):
with nn.parameter_scope('fc2'):
- h = PF.affine(h, 4)
+ h = PF.affine(h, 8)
h = F.relu(h, inplace=True)
with nn.parameter_scope('fc3'):
z3 = PF.affine(h, 5)
elif model == "convolution":
with nn.parameter_scope('conv1'):
- z = PF.convolution(x, 16, (2, 2))
+ z = PF.convolution(x, 3, (2, 2))
z2 = F.relu(z, inplace=True)
with nn.parameter_scope('fc2'):
z3 = PF.affine(z2, 5)
@@ -295,6 +295,7 @@ def backward_post_hook(f):
nn.forward_all((y, z), function_pre_hook=lambda f: None,
function_post_hook=lambda f: None)

+
@pytest.mark.parametrize("seed", [313])
def test_shared_variable_on_same_function(seed):
rng = np.random.RandomState(313)
3 changes: 2 additions & 1 deletion src/nbla/computation_graph/variable.cpp
@@ -235,7 +235,8 @@ class BackwardCallback {
if (!inputs[i]->need_grad_state())
continue;

- // If memset with 0 is reserved, accum is not used. For shared case, the first is only non-accum.
+ // If memset with 0 is reserved, accum is not used. For shared case, the
+ // first is only non-accum.
auto array = inputs[i]->variable()->grad()->array();
if (array->zeroing()) {
bool input_shared = false;
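
The rewrapped comment above concerns gradient accumulation during backward: when a zero-fill (memset with 0) of an input's grad array is still pending, the buffer can simply be overwritten, and when several functions share that input only the first write is non-accumulating. A rough NumPy sketch of the overwrite-vs-accumulate idea (illustration only, not the nbla C++ implementation):

    import numpy as np

    def write_grad(grad_buf, g, accum):
        # accum=False: overwrite the buffer; accum=True: add into it.
        if accum:
            grad_buf += g
        else:
            grad_buf[...] = g

    grad = np.empty(3)  # zeroing is pending, so contents are still garbage
    write_grad(grad, np.array([1.0, 2.0, 3.0]), accum=False)  # first writer
    write_grad(grad, np.array([0.5, 0.5, 0.5]), accum=True)   # later writers
    print(grad)  # [1.5 2.5 3.5]
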
