This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

PDF operators for the random samplers, and also the Dirichlet #14617

Merged
merged 1 commit into apache:master on Jul 19, 2019

Conversation

david-seiler
Contributor

Description

This PR replaces #14579; when I retargeted that PR from 1.3.x to master, the Jenkins CI build got confused somehow and now refuses to start new test runs (though the Travis build was fine). All the comments from the review of that PR should be addressed in this changeset.

This PR adds operators for computing the densities of samples drawn from any of the various distributions defined in operator/random, as well as their gradients, plus the Dirichlet, even though we don't yet have a sampler for it. There are many changes to test_random.py to test each PDF alongside its distribution; aside from that, the patch should be entirely stand-alone. See pdf_op.cc for more detailed description strings.
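For a sense of the intended API, here is a minimal usage sketch; the operator name random_pdf_normal and its argument names (sample, mu, sigma, is_log) are assumed from the registrations in pdf_op.cc and may differ from what the merged operators actually expose.

import mxnet as mx

# Two rows of parameters; each row of samples is evaluated against its own (mu, sigma).
mu = mx.nd.array([0.0, 2.5])
sigma = mx.nd.array([1.0, 0.5])
samples = mx.nd.random.normal(loc=0, scale=1, shape=(2, 1000))

# Densities of each sample under the corresponding per-row parameters;
# the output has the same shape as `samples`. With is_log=True the operator
# would return log-densities instead.
pdf = mx.nd.random_pdf_normal(sample=samples, mu=mu, sigma=sigma)
log_pdf = mx.nd.random_pdf_normal(sample=samples, mu=mu, sigma=sigma, is_log=True)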

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@piyushghai
Contributor

Thanks for migrating your PR from v1.3.x branch to master.
@mxnet-label-bot Add [pr-awaiting-review, operator]

@david-seiler
Contributor Author

The Python3 TensorRT GPU test failed: http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/mxnet-validation%2Funix-gpu/detail/PR-14617/1/pipeline

That failure doesn't look related to the PR: first, make warns that "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) to enable." Then, a bit later, CUDA initialization fails with error 35. I made a trivial change to rerun the tests, and the same thing happened again (along with a numeric differentiation failure that I'll investigate). Is that test known to be flaky?

@david-seiler
Contributor Author

The build checks all pass now. What's the next step to get this merged?

grad_nodes = ['v1', 'v2'] if symbdic['discrete'] else ['v0', 'v1', 'v2']
check_numeric_gradient(test_pdf, [un1, p1, p2], grad_nodes=grad_nodes, atol=backw_atol, rtol=backw_rtol, dtype=dtype)

@with_seed(1000)
Member

any reason why the seed has been fixed here?

Contributor Author

Even in float64, the numeric gradient for the Dirichlet is a little unreliable: sometimes it diverges and generates Infs or NaNs. Without a fixed seed it would still pass almost all of the time, but this is safer.

Contributor Author

Well, I said that, but all the tests were against mxnet 1.3.x. I retested against 1.5 and couldn't reproduce any of the failures, so I've removed the explicit seed.

Member

the explicit seed is still there

Contributor Author

Oops, so it was. But now it's not.

Contributor

@david-seiler It's actually still there

@roywei
Member

roywei commented Apr 29, 2019

@apeforest @eric-haibin-lin could you help review? thanks!

@asmushetzel
Contributor

It would be really good if we can get this in. There are multiple teams I know of that will benefit from it. It also addresses most of the requested features from #12932.


alpha = alpha.reshape((2, 2, num_classes))

for dtype in [np.float32, np.float64]:
Member

any reason why fp16 is not being tested?

Contributor Author

One can have fp16 but the tolerances have to be really loose, on the order of 5e-1, to get scipy and the symbolic forward to agree reliably. I've put it in, but it's something real users should be a little careful with.

Contributor Author

Update: this worked for me locally, but it didn't consistently pass in the Jenkins builds, so I've taken it back out. It's not really a sensible thing to do anyway: the Dirichlet involves a sum of lgammas, and that's just not going to be very stable in fp16 no matter what you do. If users want something like that, they should probably use higher precision internally and then downcast at the end.
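A minimal sketch of that "higher precision internally, downcast at the end" idea, written in plain numpy/scipy rather than MXNet; the helper name is illustrative, not code from this PR.

import numpy as np
from scipy.special import gammaln

def dirichlet_logpdf_as_fp16(x, alpha):
    """Dirichlet log-density evaluated in float64, returned as float16."""
    x64 = np.asarray(x, dtype=np.float64)
    a64 = np.asarray(alpha, dtype=np.float64)
    logpdf = (gammaln(a64.sum()) - gammaln(a64).sum()
              + np.sum((a64 - 1.0) * np.log(x64)))
    return logpdf.astype(np.float16)  # downcast only the final result

print(dirichlet_logpdf_as_fp16([0.2, 0.3, 0.5], [2.5, 1.0, 3.0]))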

Member

Okay.

But if you did want to test those, you could set different tolerance levels for different dtypes with something like rtol = 5e-1 if dtype == np.float16 else 1e-4 (the values here are just placeholders).

Contributor Author

Pardon me, missed this earlier. You can do that, but the thinking here is that you shouldn't: maybe there's some user somewhere who really knows what they're doing and truly wants a float16 Dirichlet, but in the common case it's a bad idea that should be avoided.

@anirudhacharya
Member

Apologies for the delay, and thanks for this contribution. It is very useful.

I have put in a few comments, but the PR looks good to me for the most part. I have tagged some of the committers to review/merge this PR.

@eric-haibin-lin @szha @reminisce @haojin2 can you please take a look?

Contributor

@aaronmarkham left a comment

Seems pretty sparse on the documentation front.

Contributor

@haojin2 left a comment

Some cosmetic issues first, taking a look at the backend code now

backw_atol = 1e-5 if dtype == np.float64 else 1e-3
backw_rtol = 1e-4 if dtype == np.float64 else 5e-2
for use_log in [False, True]:
print("use_log",use_log)
Contributor

Nit: Remove this print

Contributor Author

It's gone.

check_with_device(mx.context.current_context(), 'float64')
check_with_device(mx.context.current_context(), np.float16)
check_with_device(mx.context.current_context(), np.float32)
check_with_device(mx.context.current_context(), np.float64)
Contributor

Nit:

for dtype in [np.float16, np.float32, np.float64]:
    check_with_device(mx.context.current_context(), dtype)

Contributor Author

Done

res = results if use_log else np.exp(results)
check_symbolic_forward(test_pdf, [samples, alpha], [res], atol=forw_atol, rtol=forw_rtol, dtype=dtype)
if dtype == 'float64':
check_numeric_gradient(test_pdf, [samples, alpha], numeric_eps=1e-7, atol=backw_atol, rtol=backw_rtol, dtype=dtype)
Contributor

I saw that the numeric gradient is not checked for fp32; what is the reason behind that? I think we should have coverage for the most commonly used data type. Can you also add a symbolic backward check using check_symbolic_backward?

Contributor Author

Unfortunately, we don't have an independent source of truth for the gradients the way we do for the PDF itself. We check our symbolic forward against the densities given by scipy, but scipy doesn't have functions for analytic gradients of the PDFs, just the same kinds of tools for numeric differentiation that we've got, so there's nothing for check_symbolic_backward to check against.

Similarly, we've found that the numeric gradient is most reliable in float64. The closed-form gradient functions we've written are (assuming they're written correctly at all) more accurate than a numeric approximation can be. Checking the gradient numerically in float64 provides a lot of evidence about whether the gradient functions are written correctly, but checking float32 doesn't add much more: discrepancies there are more likely to be caused by numeric errors in the approximation than by errors in our code.
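For concreteness, this is the kind of float64 central-difference check being described, using scipy's Dirichlet log-density as the function under test; the helper below is illustrative and not the actual test_random.py code.

import numpy as np
from scipy.stats import dirichlet

def numeric_grad(f, x, eps=1e-7):
    """Central-difference gradient of a scalar function f at x, in float64."""
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2.0 * eps)
    return grad

alpha = np.array([2.5, 1.0, 3.0])
sample = np.array([0.2, 0.3, 0.5])  # a point on the simplex

# Numeric estimate of d/d_alpha of log p(sample | alpha); an analytic gradient
# implementation would be compared against this with fairly loose tolerances.
print(numeric_grad(lambda a: dirichlet.logpdf(sample, a), alpha))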

const IType2 *cur_alpha = alpha+index*k;
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
DType sum_alpha(0);
for ( int j = 0; j < k; ++j ) {
Contributor

Nit: Don't need the extra spaces within the parentheses

for (int j = 0; j < k; ++j) {

Same for several other places.

Contributor Author

Fixed

const std::vector<TBlob>& outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), pnum+3);
CHECK_EQ(outputs.size(), pnum+1);
Contributor

Nit: space around +

  CHECK_EQ(inputs.size(), pnum + 3);
  CHECK_EQ(outputs.size(), pnum + 1);

Same applies for some other operators.

Contributor Author

Lots more operator whitespace now. Might be useful to have this in the linter.

@david-seiler david-seiler force-pushed the master branch 6 times, most recently from 2d1d971 to fb4decb on May 7, 2019 16:49
@david-seiler
Contributor Author

Commenting to reopen. I was trying to debug a mysterious linker failure in unix-gpu, which does not occur in mainline but which nevertheless seems to exist independently of any of these changes.

@david-seiler david-seiler reopened this Jun 3, 2019
@david-seiler david-seiler force-pushed the master branch 9 times, most recently from d826ca0 to fd73f6c on June 4, 2019 13:33
@@ -24,7 +24,7 @@ $env:MXNET_HOME=[io.path]::combine($PSScriptRoot, 'mxnet_home')

C:\Python37\Scripts\pip install -r tests\requirements.txt
C:\Python37\python.exe -m nose -v --with-timer --timer-ok 1 --timer-warning 15 --timer-filter warning,error --with-xunit --xunit-file nosetests_unittest.xml tests\python\unittest
if (! $?) { Throw ("Error running unittest") }
if (! $?) { Throw ("Error running unittest) }
Contributor

Are you sure you want to remove the "?

Contributor Author

Whoops, I was factoring some error-handling code out to PR-15147 and got a little too aggressive. Good catch, fixed now.

@vandanavk
Contributor

@samskalicky @apeforest for review

…r (plus also the PDF of the Dirichlet). Supports probabilities and log-probabilities, as well as gradients.
@sxjscience
Member

LGTM

@sxjscience
Member

This also has not checked the grad_req=kAddTo case. Let's revise it later @haojin2
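One way such a check could look, as a hedged sketch using the Symbol bind API (random_pdf_normal and its argument names are assumptions, as above): run backward once with grad_req='write' and once with grad_req='add' into pre-filled gradient buffers, and verify that the 'add' result equals the pre-filled values plus the written gradients.

import mxnet as mx
import numpy as np

sym = mx.sym.random_pdf_normal(sample=mx.sym.Variable('sample'),
                               mu=mx.sym.Variable('mu'),
                               sigma=mx.sym.Variable('sigma'))
args = {'sample': mx.nd.random.normal(shape=(2, 10)),
        'mu': mx.nd.array([0.0, 1.0]),
        'sigma': mx.nd.array([1.0, 2.0])}

def run_backward(grad_req, seed_grads):
    grads = {k: v.copy() for k, v in seed_grads.items()}
    exe = sym.bind(mx.cpu(), args=args, args_grad=grads, grad_req=grad_req)
    exe.forward(is_train=True)
    exe.backward(mx.nd.ones_like(exe.outputs[0]))
    return grads

zeros = {k: mx.nd.zeros_like(v) for k, v in args.items()}
ones = {k: mx.nd.ones_like(v) for k, v in args.items()}
written = run_backward('write', zeros)  # plain gradients
added = run_backward('add', ones)       # should be 1 + gradient everywhere
for name in args:
    np.testing.assert_allclose(added[name].asnumpy(),
                               1.0 + written[name].asnumpy(), rtol=1e-5)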

@sxjscience sxjscience merged commit b887c06 into apache:master Jul 19, 2019
@ChaiBapchya
Contributor

ChaiBapchya commented Jul 25, 2019

Curious to know whether the pdf operators aren't supposed to have an NDArray API? Currently they're only in the Symbol API, right? @sxjscience @david-seiler
Realized they do have an NDArray API; it's just not mentioned in the docs.
Thanks.

@ChaiBapchya ChaiBapchya mentioned this pull request Jul 25, 2019
anirudhacharya pushed a commit to anirudhacharya/mxnet that referenced this pull request Aug 20, 2019
…r (plus also the PDF of the Dirichlet). Supports probabilities and log-probabilities, as well as gradients. (apache#14617)
Labels
Operator, pr-awaiting-review (PR is waiting for code review)