Skip to content

Commit

Permalink
handling FFTW wisdom in FFTWEngine
Browse files Browse the repository at this point in the history
  • Loading branch information
adematti committed Jan 21, 2022
1 parent 2b41a89 commit 262d90b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 93 deletions.
4 changes: 2 additions & 2 deletions bin/config_example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ mesh:
nmesh: 512 # mesh size (int or list of 3 ints)
boxsize: # box size (floot or list of 3 floats)
boxcenter: # box center
wrap: False # whether to wrap positions using periodic boundary conditions over the box
dtype: f4 # mesh data-type for f4 (float32) or f8 (float64)
fft_engine: 'fftw' # FFT engine, either 'numpy' or 'fftw' (recommended)
fft_plan: 'estimate' # FFT planning for FFTW engine
wrap: False # whether to wrap positions using periodic boundary conditions over the box
fft_wisdom: # wisdom for 'fftw' FFT engine (optional)
save_fft_wisdom: # where to save (and try to load) wisdom for 'fftw' FFT engine (optional)
60 changes: 38 additions & 22 deletions pyrecon/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(self, value=None, info=None, nthreads=None, attrs=None, **kwargs):
Mesh information (boxsize, boxcenter, nmesh, etc.),
copied and updated with ``kwargs``.
nthreads : int
Number of threads to use in mesh calculations.
nthreads : int, default=None
Number of threads to use in mesh calculations; defaults to OpenMP's default.
attrs : dict
Dictionary of other attributes.
Expand Down Expand Up @@ -848,7 +848,7 @@ def prod_sum(self, arrays, exp=1):
raise MeshError('Issue with prod_sum')


class BaseFFTEngine(object):
class BaseFFTEngine(BaseClass):
"""
Base engine for fast Fourier transforms.
FFT engines should extend this class, by (at least) implementing:
Expand Down Expand Up @@ -895,7 +895,7 @@ def __init__(self, shape, nthreads=None, type_complex=None, type_real=None, herm
If not provided, use ``type_complex`` instead.
"""
if nthreads is None:
self.nthreads = int(os.environ.get('OMP_NUM_THREADS',1))
self.nthreads = int(os.environ.get('OMP_NUM_THREADS', '1'))
else:
self.nthreads = nthreads
self.shape = tuple(shape)
Expand Down Expand Up @@ -961,7 +961,7 @@ class FFTWEngine(BaseFFTEngine):

"""FFT engine based on :mod:`pyfftw`."""

def __init__(self, shape, nthreads=None, wisdom=None, plan='measure', **kwargs):
def __init__(self, shape, nthreads=None, wisdom=None, save_wisdom=None, plan='measure', **kwargs):
"""
Initialize :mod:`pyfftw` engine.
Expand All @@ -984,9 +984,15 @@ def __init__(self, shape, nthreads=None, wisdom=None, plan='measure', **kwargs):
Number of threads.
wisdom : string, tuple, default=None
:mod:`pyfftw` wisdom, used to accelerate further FFTs.
Precomputed :mod:`pyfftw` wisdom, used to accelerate FFTs.
If a string, should be a path to previously saved FFT wisdom (with :func:`numpy.save`).
If a tuple, directly corresponds to the wisdom.
By default the wisdom given in ``save_wisdom`` will be loaded, if exists.
save_wisdom : bool, string, default=None
If not ``None``, path where to save the wisdom.
If ``True``, the wisdom will be saved in the default path:
'wisdom.shape-{shape[0]}-{shape[1]}-{shape[2]}.type-{type}.nthreads-{nthreads}.npy'.
plan : string, default='measure'
Choices are ['estimate', 'measure', 'patient', 'exhaustive'].
Expand All @@ -998,35 +1004,46 @@ def __init__(self, shape, nthreads=None, wisdom=None, plan='measure', **kwargs):
"""
if pyfftw is None:
raise NotImplementedError('Install pyfftw to use {}'.format(self.__class__.__name__))
super(FFTWEngine,self).__init__(shape,nthreads=nthreads,**kwargs)
super(FFTWEngine, self).__init__(shape, nthreads=nthreads, **kwargs)
plan = plan.lower()
allowed_plans = ['estimate', 'measure', 'patient', 'exhaustive']
if plan not in allowed_plans:
raise MeshError('Plan {} unknown'.format(plan))
plan = 'FFTW_{}'.format(plan.upper())

if isinstance(wisdom, str):
if os.path.isfile(wisdom):
wisdom = tuple(np.load(wisdom))
dtype = self.type_real if self.hermitian else self.type_complex
wisdom_fn = 'wisdom.shape-{}.type-{}.nthreads-{:d}.npy'.format('-'.join(['{:d}'.format(s) for s in self.shape]), dtype.name, self.nthreads)
# Should we save wisdom?
if save_wisdom and isinstance(save_wisdom, str):
wisdom_fn = save_wisdom
save_wisdom = bool(save_wisdom)

if wisdom is None:
try:
wisdom = np.load(wisdom_fn)
pyfftw.import_wisdom(wisdom)
except:
pass
else:
wisdom = None
if wisdom is not None:
pyfftw.import_wisdom(wisdom)
self.log_info('Loading wisdom from {}.'.format(wisdom_fn))
elif isinstance(wisdom, str):
self.log_info('Loading wisdom from {}.'.format(wisdom))
wisdom = tuple(np.load(wisdom))
else:
pyfftw.forget_wisdom()
if self.hermitian:
fftw_f = pyfftw.empty_aligned(self.shape, dtype=self.type_real, order='C')
else:
fftw_f = pyfftw.empty_aligned(self.shape, dtype=self.type_complex, order='C')
pyfftw.import_wisdom(wisdom)

fftw_f = pyfftw.empty_aligned(self.shape, dtype=dtype, order='C')
fftw_fk = pyfftw.empty_aligned(self.hshape, dtype=self.type_complex, order='C')
self.flags = (plan,)
v = pyfftw.FFTW(fftw_f,fftw_fk,axes=range(self.ndim), direction='FFTW_FORWARD', flags=self.flags, threads=self.nthreads)
self.fftw_forward_object = pyfftw.FFTW(fftw_f, fftw_fk, axes=range(self.ndim), direction='FFTW_FORWARD', flags=self.flags, threads=self.nthreads)
self.fftw_backward_object = pyfftw.FFTW(fftw_fk, fftw_f, axes=range(self.ndim), direction='FFTW_BACKWARD', flags=self.flags, threads=self.nthreads)
# We delete these instances to save memory, see note above
self.fftw_forward_object, self.fftw_backward_object = None, None
# allow the wisdom to be accessed from outside
self.fft_wisdom = pyfftw.export_wisdom()
# Allow the wisdom to be accessed from outside
self.wisdom = pyfftw.export_wisdom()
if save_wisdom:
self.log_info('Saving wisdom to {}.'.format(wisdom_fn))
np.save(wisdom_fn, self.wisdom)

def forward(self, fun):
"""Return forward transform of ``fun``."""
Expand Down Expand Up @@ -1059,7 +1076,6 @@ def backward(self, fun, destroy_input=True):
else:
output_array = pyfftw.empty_aligned(self.shape, dtype=self.type_complex, order='C')
if self.fftw_backward_object is None:
#fftw_backward_object = pyfftw.FFTW(fun,output_array,axes=range(self.ndim),direction='FFTW_BACKWARD',flags=self.flags,threads=self.nthreads)
fftw_backward_object = pyfftw.FFTW(input_array, output_array, axes=range(self.ndim), direction='FFTW_BACKWARD', flags=self.flags, threads=self.nthreads)
toret = fftw_backward_object(normalise_idft=True)
else:
Expand Down
38 changes: 14 additions & 24 deletions pyrecon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class BaseReconstruction(BaseClass):
boxsize, boxcenter, cellsize, offset : array
Mesh properties; see :class:`MeshInfo`.
"""
def __init__(self, f=0., bias=1., los=None, fft_engine='numpy', fft_wisdom=None, fft_plan=None, wrap=False, **kwargs):
def __init__(self, f=0., bias=1., los=None, fft_engine='numpy', fft_wisdom=None, save_fft_wisdom=None, fft_plan='measure', wrap=False, **kwargs):
"""
Initialize :class:`BaseReconstruction`.
Expand All @@ -82,9 +82,16 @@ def __init__(self, f=0., bias=1., los=None, fft_engine='numpy', fft_wisdom=None,
We strongly recommend using 'fftw' for multithreaded FFTs.
fft_wisdom : string, tuple, default=None
Wisdom for FFTW, if ``fft_engine`` is 'fftw'.
Optionally, wisdom for FFTW, if ``fft_engine`` is 'fftw'.
If a string, should be a path to previously saved FFT wisdom (with :func:`numpy.save`).
If a tuple, directly corresponds to the wisdom.
By default the wisdom given in ``save_fft_wisdom`` will be loaded, if exists.
fft_plan : string, default=None
save_fft_wisdom : bool, string, default=None
If not ``None``, path where to save the wisdom for FFTW.
If ``True``, the wisdom will be saved in the default path: f'wisdom.shape-{nmesh[0]}-{nmesh[1]}-{nmesh[2]}.type-{type}.nthreads-{nthreads}.npy'.
fft_plan : string, default='measure'
Only used for FFTW. Choices are ['estimate', 'measure', 'patient', 'exhaustive'].
The increasing amount of effort spent during the planning stage to create the fastest possible transform.
Usually 'measure' is a good compromise.
Expand All @@ -103,28 +110,11 @@ def __init__(self, f=0., bias=1., los=None, fft_engine='numpy', fft_wisdom=None,
self.info = self.mesh_randoms.info
self.set_los(los)
self.log_info('Using mesh {}.'.format(self.mesh_data))
kwargs = {}
if fft_wisdom is None:
# set to default – if this file doesn't exist it will be ignored
default_wisdom_fn = os.path.join(os.getcwd(), f'wisdom.{self.mesh_data.nmesh[0]}.{self.mesh_data.nthreads}.npy')
kwargs['wisdom'] = default_wisdom_fn
else:
kwargs['wisdom'] = fft_wisdom
if fft_plan is not None: kwargs['plan'] = fft_plan
kwargs['hermitian'] = False
self.mesh_data.set_fft_engine(fft_engine, **kwargs)
self.log_info('Using {:d} nthreads.'.format(self.mesh_data.nthreads))
self.mesh_data.set_fft_engine(fft_engine, wisdom=fft_wisdom, save_wisdom=save_fft_wisdom, plan=fft_plan, hermitian=False)
self.mesh_randoms.set_fft_engine(self.mesh_data.fft_engine)
if fft_engine == 'fftw':
# allow the wisdom to be accessed from outside if necessary
self.wisdom = self.mesh_data.fft_engine.fft_wisdom
# generating the wisdom can be a large fraction of the total FFT time, so we should save it
if isinstance(fft_wisdom, str):
# if fft_wisdom contains a file name we save it there, overwriting existing file if necessary
np.save(fft_wisdom, self.wisdom)
if fft_wisdom is None:
# if no filename was provided, save it to a default file location
np.save(default_wisdom_fn, self.wisdom)
# if fft_wisdom was a tuple containing the wisdom itself, don't save anything
# Allow the wisdom to be accessed from outside if necessary
self.fft_wisdom = getattr(self.mesh_data.fft_engine, 'wisdom', None)

@property
def beta(self):
Expand Down
2 changes: 2 additions & 0 deletions pyrecon/tests/config_iterativefft_particle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ cosmology:
mesh:
nmesh: 128
dtype: f8
#fft_engine: 'fftw'
#save_fft_wisdom: 'wisdom.iterative_fft_particle.npy'
14 changes: 0 additions & 14 deletions pyrecon/tests/test_iterative_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,6 @@ def test_mem():
recon.run()
mem('recon') # 3 meshes

def test_wisdom():
recon = IterativeFFTReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure')
# wisdom created and accessible
assert recon.wisdom
default_wisdom_fn = os.path.join(os.getcwd(), f'wisdom.{recon.mesh_data.nmesh[0]}.{recon.mesh_data.nthreads}.npy')
print(default_wisdom_fn)
# wisdom was written to default wisdom file
assert os.path.isfile(default_wisdom_fn)
new_wisdom_fn = 'new_wisdomfile.npy'
recon = IterativeFFTReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure', fft_wisdom=new_wisdom_fn)
# wisdom written to custom file
assert os.path.isfile(new_wisdom_fn)
# wisdom written to both files is the same
assert tuple(np.load(default_wisdom_fn)) == tuple(np.load(new_wisdom_fn))

def test_iterative_fft_wrap():
size = 100000
Expand Down
37 changes: 28 additions & 9 deletions pyrecon/tests/test_iterative_fft_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,36 @@ def test_mem():
recon.run()
mem('recon') # 3 meshes


def test_wisdom():
recon = IterativeFFTParticleReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure')
# wisdom created and accessible
assert recon.wisdom
default_wisdom_fn = os.path.join(os.getcwd(), f'wisdom.{recon.mesh_data.nmesh[0]}.{recon.mesh_data.nthreads}.npy')
print(default_wisdom_fn)
# wisdom was written to default wisdom file

def remove(fn):
try: os.remove(fn)
except OSError: pass

default_wisdom_fn = 'wisdom.shape-64-64-64.type-complex128.nthreads-1.npy'
remove(default_wisdom_fn)

recon = IterativeFFTParticleReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure', nthreads=1)
# Wisdom created and accessible
assert getattr(recon, 'fft_wisdom', None)
assert not os.path.isfile(default_wisdom_fn)

recon = IterativeFFTParticleReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure', save_fft_wisdom=True, nthreads=1)
# Wisdom created and accessible
# Wisdom was written to default wisdom file
assert os.path.isfile(default_wisdom_fn)

new_wisdom_fn = 'new_wisdomfile.npy'
recon = IterativeFFTParticleReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure', fft_wisdom=new_wisdom_fn)
# wisdom written to custom file
remove(new_wisdom_fn)
recon = IterativeFFTParticleReconstruction(f=0.8, bias=2, los='z', boxsize=1000, boxcenter=500, nmesh=64, fft_engine='fftw', fft_plan='measure', save_fft_wisdom=new_wisdom_fn, nthreads=1)
# Wisdom written to custom file
assert os.path.isfile(new_wisdom_fn)
# wisdom written to both files is the same
# Wisdom written to both files is the same
assert tuple(np.load(default_wisdom_fn)) == tuple(np.load(new_wisdom_fn))
remove(default_wisdom_fn)
remove(new_wisdom_fn)


def test_iterative_fft_particle_wrap():
size = 100000
Expand Down Expand Up @@ -376,6 +392,9 @@ def test_script_no_randoms(data_fn, output_data_fn):
script_output_randoms_fn = os.path.join(catalog_dir,'script_randoms_rec.fits')

#test_mem()
test_script(data_fn,randoms_fn,script_output_data_fn,script_output_randoms_fn)
exit()
test_wisdom()
test_no_nrandoms()
test_dtype()
test_los()
Expand Down
70 changes: 48 additions & 22 deletions pyrecon/tests/test_mesh.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import tempfile

import pytest
import numpy as np

Expand Down Expand Up @@ -102,35 +105,59 @@ def test_fft():
try: import pyfftw
except ImportError: pyfftw = None

for dtype in ['f4','f8']:
mesh = RealMesh(value=1.,boxsize=1000.,boxcenter=0.,nmesh=4,dtype='f8')
cmesh = mesh.to_complex()
from pyrecon.mesh import NumpyFFTEngine
assert isinstance(mesh.fft_engine, NumpyFFTEngine)
if pyfftw is not None:
cmesh = mesh.to_complex(engine='fftw', plan='estimate')
from pyrecon.mesh import FFTWEngine
assert isinstance(cmesh.fft_engine, FFTWEngine)

fft_engine = cmesh.fft_engine
cmesh = cmesh + 1
assert cmesh.fft_engine is fft_engine
mesh = cmesh.to_real()
assert mesh.fft_engine is fft_engine
mesh.smooth_gaussian(15., method='fft')
assert mesh.fft_engine is fft_engine

def remove(fn):
try: os.remove(fn)
except OSError: pass

for dtype in ['f4', 'f8']:
mesh = RealMesh(boxsize=1000.,boxcenter=0.,nmesh=4,dtype=dtype)
mesh.value = np.random.uniform(0.,1.,mesh.shape)
lkwargs = [{'engine':'numpy'}]
if pyfftw is not None:
lkwargs += [{'engine':'fftw'}, {'engine':'fftw', 'plan':'estimate'}]
lkwargs += [{'engine':'fftw', 'save_wisdom':True}, {'engine':'fftw', 'save_wisdom':'new_wisdomfile.npy'}]

for hermitian in [True, False]:
for kwargs in [{'engine':'numpy'}] + ([{'engine':'fftw'},{'engine':'fftw','plan':'estimate'}] if pyfftw is not None else []):
for kwargs in lkwargs:
mesh_copy = mesh.value.copy()
mesh1 = mesh.to_complex(hermitian=hermitian,**kwargs)
engine = kwargs.get('engine', None)
if engine == 'fftw':
dtype = mesh.dtype if hermitian else (1j*np.empty(0, dtype=mesh.dtype)).dtype
default_wisdom_fn = 'wisdom.shape-{}.type-{}.nthreads-{:d}.npy'.format('-'.join(['{:d}'.format(n) for n in mesh.nmesh]), dtype.name, mesh.nthreads)
remove(default_wisdom_fn)
mesh1 = mesh.to_complex(hermitian=hermitian, **kwargs)
assert np.all(mesh.value == mesh_copy)
mesh_copy = mesh1.value.copy()
mesh2 = mesh1.to_real()
assert np.all(mesh1.value == mesh_copy)
assert np.allclose(mesh2,mesh,atol=1e-5)

mesh = RealMesh(value=1.,boxsize=1000.,boxcenter=0.,nmesh=4,dtype=dtype)
cmesh = mesh.to_complex()
from pyrecon.mesh import NumpyFFTEngine
assert isinstance(mesh.fft_engine, NumpyFFTEngine)
if pyfftw is not None:
cmesh = mesh.to_complex(engine='fftw', plan='estimate')
from pyrecon.mesh import FFTWEngine
assert isinstance(cmesh.fft_engine, FFTWEngine)
fft_engine = cmesh.fft_engine
cmesh = cmesh + 1
assert cmesh.fft_engine is fft_engine
mesh = cmesh.to_real()
assert mesh.fft_engine is fft_engine
mesh.smooth_gaussian(15., method='fft')
assert mesh.fft_engine is fft_engine
assert np.allclose(mesh2, mesh, atol=1e-5)
if engine == 'fftw':
save_wisdom = kwargs.get('save_wisdom', None)
if isinstance(save_wisdom, str):
assert os.path.isfile(save_wisdom)
remove(save_wisdom)
elif save_wisdom:
assert os.path.isfile(default_wisdom_fn)
remove(default_wisdom_fn)
else:
assert not os.path.isfile(default_wisdom_fn)


def test_hermitian():
Expand Down Expand Up @@ -278,11 +305,10 @@ def test_timing():
#test_timing()
#test_misc()
#test_timing()
#test_fft()
test_info()
test_cic()
test_finite_difference_cic()
test_smoothing()
test_fft()
test_hermitian()
test_smoothing()
test_misc()

0 comments on commit 262d90b

Please sign in to comment.