diff --git a/pytorch_binding/warpctc_pytorch/__init__.py b/pytorch_binding/warpctc_pytorch/__init__.py index 5f7cf74..4405140 100644 --- a/pytorch_binding/warpctc_pytorch/__init__.py +++ b/pytorch_binding/warpctc_pytorch/__init__.py @@ -2,11 +2,16 @@ import warpctc_pytorch as warp_ctc from torch.autograd import Function from torch.nn import Module -from torch.nn.modules.loss import _assert_no_grad from ._warp_ctc import * +def _assert_no_grad(tensor): + assert not tensor.requires_grad, \ + "gradients only computed for acts - please " \ + "mark other tensors as not requiring gradients" + + class _CTC(Function): @staticmethod def forward(ctx, acts, labels, act_lens, label_lens, size_average=False,