Numerical precision issues without TF32 #2

Open
prutschman-iv opened this issue Jul 19, 2024 · 2 comments

@prutschman-iv

I was seeing what seemed like poor numerical accuracy (even with TF32 disabled) compared to the scipy implementation of RBFInterpolator, so I made the following test case. It creates a regularly spaced grid of points, displaces them randomly (but not enough to create "pinches"), and then compares the interpolated values at the regular grid points. With infinite precision, a thin plate spline interpolator would reproduce the displaced values exactly.

import sys

import numpy as np
import scipy.interpolate
import torch

import torchrbf

# TF32 matmul must be disabled so both libraries compute in true IEEE precision.
assert not torch.backends.cuda.matmul.allow_tf32
torch.set_default_device('cuda')

print('python ', sys.version)
print('torch ', torch.__version__)

for k in range(1, 8):
    print(f"{(2**k)**2} control points")
    # Regular 2**k x 2**k grid, plus a small random displacement of each point.
    pts = np.indices((2**k, 2**k)).reshape(2, -1).T.astype(np.float32)
    pts_offset = (pts + np.random.uniform(-0.1, 0.1, size=pts.shape)).astype(np.float32)

    tpts = torch.tensor(pts)
    tpts_offset = torch.tensor(pts_offset)

    # Fit grid -> displaced points, then evaluate back at the grid nodes.
    # With exact arithmetic the interpolant reproduces pts_offset exactly,
    # so the residual below measures pure numerical error.
    rbf = scipy.interpolate.RBFInterpolator(pts, pts_offset)
    errs = rbf(pts) - pts_offset
    error_mags = np.hypot(*errs.T)
    print('  scipy   \t', error_mags.max())

    trbf = torchrbf.RBFInterpolator(tpts, tpts_offset, device='cuda')
    errs = trbf(tpts) - tpts_offset
    error_mags = torch.hypot(*errs.T)
    print('  torchrbf\t', error_mags.max().cpu().numpy())

When I run this on my Windows machine with an RTX 4090, I get the following:

python  3.11.9 (tags/v3.11.9:de54cf5, Apr  2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
torch  2.3.1+cu121
4 control points
  scipy   	 0.0
  torchrbf	 2.3964506e-07
16 control points
  scipy   	 8.95090418262362e-16
  torchrbf	 1.0662403e-06
64 control points
  scipy   	 2.6615730177631208e-14
  torchrbf	 1.5168344e-05
256 control points
  scipy   	 6.629405610737136e-13
  torchrbf	 0.00014111983
1024 control points
  scipy   	 1.3932322615057043e-11
  torchrbf	 0.007374375
4096 control points
  scipy   	 2.662584573260454e-10
  torchrbf	 0.15812844
16384 control points
  scipy   	 3.6122964516631894e-09
  torchrbf	 179.04927

Do you have any suggestions or advice on improving the situation? 16k control points is a bit on the excessive side, but I have a practical application that could easily use on the order of 4k points, and at that scale the torchrbf errors are on the order of the grid displacement itself.

@ArmanMaesumi (Owner)

After quite a bit of debugging, it dawned on me that numpy's default dtype is float64, whereas pytorch's default is of course float32. Admittedly this is quite a big oversight on my part!
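
For reference, the mismatched defaults are easy to verify (a quick check, nothing torchrbf-specific):

import numpy as np
import torch

print(np.array([1.0]).dtype)      # float64 -- numpy's default float dtype
print(torch.tensor([1.0]).dtype)  # torch.float32 -- pytorch's default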

To confirm, I manually converted all (relevant) internal tensors to torch.float64, and indeed the precision is now within a reasonable margin of scipy's.

Perhaps I should include an optional argument that forces torchrbf to use higher precision internally. In most cases it is probably not necessary, so the default can stay at the current behavior (float32). And of course GPU performance tends to drop drastically at higher precision, so you'll want to avoid it when possible.
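
A rough sketch of what that could look like; note that the dtype keyword below is hypothetical and not part of the current torchrbf API:

# Hypothetical option, not an existing torchrbf argument:
trbf = torchrbf.RBFInterpolator(
    tpts, tpts_offset,
    device='cuda',
    dtype=torch.float64,  # do the internal solve in double precision
)
# The default would remain torch.float32, matching the current behavior.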

If you want to quickly hack together a local fix, you can Ctrl+F ".float()" in torchrbf/RBFInterpolator.py and change them all to ".double()". You'll also want to make sure your input tensors (data coordinates and data values) are converted, as in the sketch below.
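
For example, on the caller's side, the test case above would then build its inputs in float64 (a sketch, assuming the .double() patch described above has been applied):

# Assumes .float() has been changed to .double() inside torchrbf/RBFInterpolator.py.
tpts = torch.tensor(pts, dtype=torch.float64)
tpts_offset = torch.tensor(pts_offset, dtype=torch.float64)

trbf = torchrbf.RBFInterpolator(tpts, tpts_offset, device='cuda')
errs = trbf(tpts) - tpts_offset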

I'm currently a bit occupied so I won't have time to push a patch for now.

@prutschman-iv (Author)

Thank you, I will give your suggestion a try!
