From 90fc14f12d44747998abb92aa6bfa20532133151 Mon Sep 17 00:00:00 2001
From: Mark Richardson
Date: Thu, 28 Jun 2018 12:49:54 -0700
Subject: [PATCH] Fix broken _assert_no_grad import

---
 pytorch_binding/warpctc_pytorch/__init__.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/pytorch_binding/warpctc_pytorch/__init__.py b/pytorch_binding/warpctc_pytorch/__init__.py
index 5f7cf74..4405140 100644
--- a/pytorch_binding/warpctc_pytorch/__init__.py
+++ b/pytorch_binding/warpctc_pytorch/__init__.py
@@ -2,11 +2,16 @@
 import warpctc_pytorch as warp_ctc
 from torch.autograd import Function
 from torch.nn import Module
-from torch.nn.modules.loss import _assert_no_grad
 from ._warp_ctc import *
 
 
+def _assert_no_grad(tensor):
+    assert not tensor.requires_grad, \
+        "gradients only computed for acts - please " \
+        "mark other tensors as not requiring gradients"
+
+
 class _CTC(Function):
     @staticmethod
     def forward(ctx, acts, labels, act_lens, label_lens, size_average=False,
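
Context for the change above: newer PyTorch releases no longer ship the private _assert_no_grad helper in torch.nn.modules.loss, which is why the old import broke; the patch re-defines an equivalent check locally. The following is a minimal, hedged sketch of how the re-defined helper behaves on its own; the tensor names (labels, labels_with_grad) are hypothetical examples and not part of the patch.

# Sketch only, not part of the patch: demonstrates the behavior of the
# locally defined _assert_no_grad check. Tensor names are hypothetical.
import torch

def _assert_no_grad(tensor):
    assert not tensor.requires_grad, \
        "gradients only computed for acts - please " \
        "mark other tensors as not requiring gradients"

labels = torch.zeros(4, dtype=torch.int32)          # does not track gradients
_assert_no_grad(labels)                              # passes silently

labels_with_grad = torch.zeros(4, requires_grad=True)  # tracks gradients
try:
    _assert_no_grad(labels_with_grad)                # raises AssertionError
except AssertionError as err:
    print(err)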