Skip to content

Commit

Permalink
Merge pull request #77 from hpc4cmb/madam_output
Browse files Browse the repository at this point in the history
Add a unit test for Madam output. Force Madam to return destriped tim…
  • Loading branch information
tskisner authored Dec 13, 2016
2 parents 994f738 + c7fd725 commit 371aa89
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 23 deletions.
57 changes: 57 additions & 0 deletions tests/test_ops_madam.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_madam_gradient(self):
pars[ 'kfilter' ] = 'F'
pars[ 'run_submap_test' ] = 'F'
pars[ 'path_output' ] = self.mapdir
pars[ 'info' ] = 0

madam = OpMadam(params=pars, name='grad', dets=self.dets)
if madam.available:
Expand All @@ -118,3 +119,59 @@ def test_madam_gradient(self):
else:
print("libmadam not available, skipping tests")

def test_madam_output(self):
start = MPI.Wtime()

# add simple sky gradient signal
grad = OpSimGradient(nside=self.sim_nside)
grad.exec(self.data)

# make a simple pointing matrix
pointing = OpPointingHpix(nside=self.map_nside, nest=True)
pointing.exec(self.data)

handle = None
if self.comm.rank == 0:
handle = open(os.path.join(self.outdir,"out_test_madam_info"), "w")
self.data.info(handle)
if self.comm.rank == 0:
handle.close()

pars = {}
pars[ 'kfirst' ] = 'T'
pars[ 'iter_max' ] = 100
pars[ 'base_first' ] = 1.0
pars[ 'fsample' ] = self.rate
pars[ 'nside_map' ] = self.map_nside
pars[ 'nside_cross' ] = self.map_nside
pars[ 'nside_submap' ] = self.map_nside
pars[ 'write_map' ] = 'F'
pars[ 'write_binmap' ] = 'T'
pars[ 'write_matrix' ] = 'F'
pars[ 'write_wcov' ] = 'F'
pars[ 'write_hits' ] = 'T'
pars[ 'kfilter' ] = 'F'
pars[ 'run_submap_test' ] = 'F'
pars[ 'path_output' ] = self.mapdir
pars[ 'info' ] = 0

madam = OpMadam(params=pars, name='grad', name_out='destriped', dets=self.dets)
if madam.available:
tod = self.data.obs[0]['tod']
det = 'bore'
ref_in = tod.cache.reference('grad_'+det)
rms0 = np.std(ref_in)
ref_in[ref_in.size//2:] += 1e6 # Add an offset
rms1 = np.std(ref_in)

madam.exec(self.data)
stop = MPI.Wtime()
elapsed = stop - start
self.print_in_turns("Madam test took {:.3f} s".format(elapsed))

ref_out = tod.cache.reference('destriped_'+det)
rms2 = np.std(ref_out)
if rms1 < 0.9*rms2:
raise Exception('Destriped TOD does not have lower RMS')
else:
print("libmadam not available, skipping tests")
54 changes: 31 additions & 23 deletions toast/map/madam.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,18 @@ def __init__(self, params={}, timestamps_name=None, detweights=None,
self._apply_flags = apply_flags
self._params = params
if dets is not None:
self._dets = set( dets )
self._dets = set(dets)
else:
self._dets = None
self._mcmode = mcmode
if mcmode:
self._params['mcmode'] = True
else:
self._params['mcmode'] = False
if self._name_out is not None:
self._params['write_tod'] = True
else:
self._params['write_tod'] = False
self._cached = False
self._noisekey = noise

Expand Down Expand Up @@ -314,13 +318,13 @@ def exec(self, data):
if (ival.last >= local_offset
and ival.first < (local_offset + local_nsamp)):
local_start = ival.first - local_offset
local_stop = ival.last - local_offset + 1
local_stop = ival.last - local_offset
if local_start < 0:
local_start = 0
if local_stop > local_nsamp:
local_stop = local_nsamp
period_lengths.append(local_stop - local_start)
period_ranges.append((local_start, local_stop))
if local_stop > local_nsamp - 1:
local_stop = local_nsamp - 1
period_lengths.append(local_stop - local_start + 1)
period_ranges.append((local_start, local_stop + 1))
obs_period_ranges.append(period_ranges)

nperiod = len(period_lengths)
Expand Down Expand Up @@ -362,21 +366,24 @@ def exec(self, data):
# entries in the dictionary when the PSD actually changes
if self._noisekey in obs.keys():
nse = obs[self._noisekey]
if psdfreqs is None:
psdfreqs = nse.freq(detectors[0]).astype(np.float64).copy()
npsdbin = len(psdfreqs)
for d in range(ndet):
det = detectors[d]
check_psdfreqs = nse.freq(det)
if not np.allclose(psdfreqs, check_psdfreqs):
raise RuntimeError('All PSDs passed to Madam must have'
' the same frequency binning.')
psd = nse.psd(det)
if det not in psds:
psds[det] = [(0, psd)]
else:
if not np.allclose(psds[det][-1][1], psd):
psds[det] += [(timestamps[0], psd)]
if nse is not None:
if psdfreqs is None:
psdfreqs = nse.freq(detectors[0]).astype(
np.float64).copy()
npsdbin = len(psdfreqs)
for d in range(ndet):
det = detectors[d]
check_psdfreqs = nse.freq(det)
if not np.allclose(psdfreqs, check_psdfreqs):
raise RuntimeError(
'All PSDs passed to Madam must have'
' the same frequency binning.')
psd = nse.psd(det)
if det not in psds:
psds[det] = [(0, psd)]
else:
if not np.allclose(psds[det][-1][1], psd):
psds[det] += [(timestamps[0], psd)]

for d in range(ndet):
# Get the signal.
Expand Down Expand Up @@ -531,18 +538,19 @@ def exec(self, data):
self._cached = True

if self._name_out is not None:
global_offset = 0
for obs, period_ranges in zip(data.obs, obs_period_ranges):
tod = obs['tod']
nlocal = tod.local_samples[1]
for d, det in enumerate(detectors):
signal = np.ones(nlocal) * np.nan
offset = 0
offset = global_offset
for istart, istop in period_ranges:
nn = istop - istart
signal[istart:istop] = madam_signal[offset:offset+nn]
offset += nn
cachename = "{}_{}".format(self._name_out, det)
tod.cache.put(cachename, signal, replace=True)
offset += nlocal
global_offset = offset

return

0 comments on commit 371aa89

Please sign in to comment.