From ce76db8579c862d5b2dd5d08ca46be41090ce68e Mon Sep 17 00:00:00 2001
From: "Brandon N. Benton"
Date: Fri, 30 Aug 2024 20:05:20 -0600
Subject: [PATCH] split up attribute collection for meta and time index.
 significantly reduced mem use since number of spatial chunks tends to be
 much lower than number of time chunks.

---
 sup3r/pipeline/strategy.py            |  23 ++-
 sup3r/postprocessing/collectors/h5.py | 255 +++++++++++++++-----------
 sup3r/postprocessing/writers/base.py  |   2 +-
 sup3r/postprocessing/writers/h5.py    |   2 +-
 4 files changed, 163 insertions(+), 119 deletions(-)

diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py
index 09efaf092..47e5310c5 100644
--- a/sup3r/pipeline/strategy.py
+++ b/sup3r/pipeline/strategy.py
@@ -220,10 +220,10 @@ def __post_init__(self):
             temporal_pad=self.temporal_pad,
         )
         self.n_chunks = self.fwp_slicer.n_chunks
+        self.out_files = self.get_out_files(out_files=self.out_pattern)
         self.node_chunks = self._get_node_chunks()
 
         if not self.head_node:
-            self.out_files = self.get_out_files(out_files=self.out_pattern)
             self.hr_lat_lon = self.get_hr_lat_lon()
             hr_shape = self.hr_lat_lon.shape[:-1]
             self.gids = np.arange(np.prod(hr_shape)).reshape(hr_shape)
@@ -280,8 +280,21 @@ def _get_node_chunks(self):
         """Get array of lists such that node_chunks[i] is a list of indices
         for the ith node indexing the chunks that will be sent through the
         generator on the ith node."""
-        node_chunks = min(self.max_nodes or np.inf, self.n_chunks)
-        return np.array_split(np.arange(self.n_chunks), node_chunks)
+        logger.info('Checking for unfinished chunks.')
+        unfinished_chunks = [
+            n
+            for n in range(self.n_chunks)
+            if not self.chunk_finished(chunk_idx=n, log=False)
+        ]
+        logger.info(
+            '%s of %s chunks are unfinished.',
+            len(unfinished_chunks),
+            self.n_chunks,
+        )
+        node_chunks = min(
+            self.max_nodes or np.inf, max((1, len(unfinished_chunks)))
+        )
+        return np.array_split(unfinished_chunks, node_chunks)
 
     def _get_fwp_chunk_shape(self):
         """Get fwp_chunk_shape with default shape equal to the input handler
@@ -550,14 +563,14 @@ def node_finished(self, node_idx):
         """Check if all out files for a given node have been saved"""
         return all(self.chunk_finished(i) for i in self.node_chunks[node_idx])
 
-    def chunk_finished(self, chunk_idx):
+    def chunk_finished(self, chunk_idx, log=True):
         """Check if process for given chunk_index has already been run.
         Considered finished if there is already an output file and incremental
         is False."""
         out_file = self.out_files[chunk_idx]
         check = os.path.exists(out_file) and self.incremental
-        if check:
+        if check and log:
             logger.info(
                 '%s already exists and incremental = True. Skipping forward '
                 'pass for chunk index %s.',
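The reworked _get_node_chunks distributes only the unfinished chunk indices
across nodes, so a mostly-finished run no longer requests idle nodes. A
standalone sketch of the split (NumPy only; the chunk indices and node count
below are made up for illustration):

    import numpy as np

    unfinished_chunks = [0, 3, 4, 7, 8, 9]  # e.g. 6 of 10 chunks still to run
    max_nodes = 4

    # Same logic as _get_node_chunks: never ask for more nodes than there
    # are unfinished chunks, and always use at least one node.
    n_nodes = int(min(max_nodes or np.inf, max((1, len(unfinished_chunks)))))
    print(np.array_split(unfinished_chunks, n_nodes))
    # -> [array([0, 3]), array([4, 7]), array([8]), array([9])]

Moving the out_files assignment ahead of _get_node_chunks is what makes this
possible: chunk_finished needs self.out_files, and it is now called on the
head node as well.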
diff --git a/sup3r/postprocessing/collectors/h5.py b/sup3r/postprocessing/collectors/h5.py
index bc13d06e0..1e3648835 100644
--- a/sup3r/postprocessing/collectors/h5.py
+++ b/sup3r/postprocessing/collectors/h5.py
@@ -197,21 +197,27 @@ def get_data(
             logger.error(msg)
             raise OSError(msg) from e
 
-    def _get_file_attrs(self, file):
-        """Get meta data and time index for a single file"""
-        if file in self.file_attrs:
-            meta = self.file_attrs[file]['meta']
-            time_index = self.file_attrs[file]['time_index']
-        else:
-            with RexOutputs(file, mode='r') as f:
-                meta = f.meta
-                time_index = f.time_index
-        if file not in self.file_attrs:
-            self.file_attrs[file] = {'meta': meta, 'time_index': time_index}
-        logger.debug(
-            'Finished getting info for file: %s. %s', file, _mem_check()
-        )
-        return meta, time_index
+    def _get_file_time_index(self, file):
+        """Get time index for a single file. Simple method used in thread pool
+        for attribute collection."""
+        with RexOutputs(file, mode='r') as f:
+            time_index = f.time_index
+            logger.debug(
+                'Finished getting time index for file: %s. %s',
+                file,
+                _mem_check(),
+            )
+        return time_index
+
+    def _get_file_meta(self, file):
+        """Get meta for a single file. Simple method used in thread pool for
+        attribute collection."""
+        with RexOutputs(file, mode='r') as f:
+            meta = f.meta
+            logger.debug(
+                'Finished getting meta for file: %s. %s', file, _mem_check()
+            )
+        return meta
 
     def get_unique_chunk_files(self, file_paths):
         """We get files for the unique spatial and temporal extents covered by
@@ -222,15 +228,18 @@
 
-        Parameters
-        ----------
-        file_paths : list | str
-            Explicit list of str file paths that will be sorted and collected
-            or a single string with unix-style /search/patt*ern.h5.
+        Returns
+        -------
+        t_files : list
+            Explicit list of str file paths which, when combined, provide the
+            entire temporal extent.
+        s_files : list
+            Explicit list of str file paths which, when combined, provide the
+            entire spatial domain.
         """
         t_chunk, s_chunk = self.get_chunk_indices(file_paths[0])
         t_files = file_paths[0].replace(f'{t_chunk}_{s_chunk}', f'*_{s_chunk}')
         t_files = set(glob(t_files)).intersection(file_paths)
         logger.info('Found %s unique temporal chunks', len(t_files))
         s_files = file_paths[0].replace(f'{t_chunk}_{s_chunk}', f'{t_chunk}_*')
         s_files = set(glob(s_files)).intersection(file_paths)
         logger.info('Found %s unique spatial chunks', len(s_files))
-        return list(s_files) + list(t_files)
+        return list(t_files), list(s_files)
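Because chunk files share the ``_{temporal_chunk_index}_{spatial_chunk_index}.h5``
suffix, the files for a single spatial chunk cover the full time index and the
files for a single temporal chunk cover the full meta. A minimal sketch of that
pattern matching (the file names and directory are hypothetical; only glob and
the suffix convention above are assumed):

    from glob import glob

    file_paths = sorted(glob('./chunks/sup3r_chunk_*_*.h5'))
    first = file_paths[0]  # e.g. './chunks/sup3r_chunk_000000_000000.h5'
    t_chunk, s_chunk = first[:-3].split('_')[-2:]

    # every temporal chunk for one spatial chunk -> full time index
    t_files = set(glob(first.replace(f'{t_chunk}_{s_chunk}', f'*_{s_chunk}')))
    # every spatial chunk for one temporal chunk -> full meta
    s_files = set(glob(first.replace(f'{t_chunk}_{s_chunk}', f'{t_chunk}_*')))

With a chunk grid of T temporal by S spatial chunks this touches T + S files
instead of T * S, which is where the memory saving in the commit message
comes from.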
 
-    def _get_collection_attrs(
-        self, file_paths, sort=True, sort_key=None, max_workers=None
-    ):
+    def _get_collection_attrs(self, file_paths, max_workers=None):
         """Get important dataset attributes from a file list to be collected.
 
         Assumes the file list is chunked in time (row chunked).
 
         Parameters
         ----------
         file_paths : list | str
             Explicit list of str file paths that will be sorted and collected
             or a single string with unix-style /search/patt*ern.h5.
-        sort : bool
-            flag to sort flist to determine meta data order.
-        sort_key : None | fun
-            Optional sort key to sort flist by (determines how meta is built
-            if out_file does not exist).
         max_workers : int | None
             Number of workers to use in parallel. 1 runs serial,
-            None will use all available workers.
-        target_meta_file : str
-            Path to target final meta containing coordinates to keep from the
-            full list of coordinates present in the collected meta for the full
-            file list.
-        threshold : float
-            Threshold distance for finding target coordinates within full meta
+            None uses all available.
 
         Returns
         -------
         time_index : pd.DatetimeIndex
             Concatenated full size time index from the flist that is being
             collected
         meta : pd.DataFrame
             Concatenated full size meta data from the flist that is being
             collected or provided target meta
         """
-        if sort:
-            file_paths = sorted(file_paths, key=sort_key)
-
-        logger.info(
-            'Getting collection attrs for full dataset with '
-            'max_workers=%s. %s', max_workers, _mem_check()
-        )
-        time_index = [None] * len(file_paths)
-        meta = [None] * len(file_paths)
-        tasks = [dask.delayed(self._get_file_attrs)(fn) for fn in file_paths]
+        t_files, s_files = self.get_unique_chunk_files(file_paths)
+        meta_tasks = [dask.delayed(self._get_file_meta)(fn) for fn in s_files]
+        ti_tasks = [
+            dask.delayed(self._get_file_time_index)(fn) for fn in t_files
+        ]
         if max_workers == 1:
-            out = dask.compute(*tasks, scheduler='single-threaded')
+            meta = dask.compute(*meta_tasks, scheduler='single-threaded')
+            time_index = dask.compute(*ti_tasks, scheduler='single-threaded')
         else:
-            out = dask.compute(
-                *tasks, scheduler='threads', num_workers=max_workers
+            meta = dask.compute(
+                *meta_tasks, scheduler='threads', num_workers=max_workers
+            )
+            time_index = dask.compute(
+                *ti_tasks, scheduler='threads', num_workers=max_workers
             )
         logger.info(
             'Finished getting meta and time_index for all unique chunks.'
         )
-        for i, vals in enumerate(out):
-            meta[i], time_index[i] = vals
-            logger.debug(
-                'Finished filling arrays for file %s. %s', i, _mem_check()
-            )
         time_index = pd.DatetimeIndex(np.concatenate(time_index))
         time_index = time_index.sort_values()
         time_index = time_index.drop_duplicates()
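The new attribute collection builds two separate sets of delayed tasks, one
over the unique spatial chunks for meta and one over the unique temporal
chunks for the time index, and runs them on dask's thread pool. A
self-contained sketch of the pattern (the reader below is a stand-in for the
RexOutputs reads; only dask.delayed and dask.compute are assumed):

    import dask

    def read_time_index(fn):
        # stand-in for: with RexOutputs(fn, mode='r') as f: return f.time_index
        return [fn]

    t_files = ['chunk_000000_000000.h5', 'chunk_000001_000000.h5']
    tasks = [dask.delayed(read_time_index)(fn) for fn in t_files]

    # scheduler='threads' suits I/O-bound HDF5 reads; num_workers caps the pool
    time_indexes = dask.compute(*tasks, scheduler='threads', num_workers=4)

Splitting meta from time-index collection means the potentially large meta
frames are only held for the (typically few) spatial chunks rather than for
every file in the list.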
@@ -360,8 +350,6 @@ def get_target_and_masked_meta(
     def get_collection_attrs(
         self,
         file_paths,
-        sort=True,
-        sort_key=None,
         max_workers=None,
         target_meta_file=None,
         threshold=1e-4,
@@ -375,11 +363,6 @@
         file_paths : list | str
             Explicit list of str file paths that will be sorted and collected
             or a single string with unix-style /search/patt*ern.h5.
-        sort : bool
-            flag to sort flist to determine meta data order.
-        sort_key : None | fun
-            Optional sort key to sort flist by (determines how meta is built
-            if out_file does not exist).
         max_workers : int | None
             Number of workers to use in parallel. 1 runs serial,
             None will use all available workers.
@@ -414,7 +397,7 @@
         assert os.path.exists(target_meta_file), msg
 
         time_index, meta = self._get_collection_attrs(
-            file_paths, sort=sort, sort_key=sort_key, max_workers=max_workers
+            file_paths, max_workers=max_workers
         )
         logger.info('Getting target and masked meta.')
         target_meta, masked_meta = self.get_target_and_masked_meta(
@@ -649,6 +632,90 @@ def get_flist_chunks(self, file_paths, n_writes=None):
         )
         return flist_chunks
 
+    def collect_feature(
+        self,
+        dset,
+        target_masked_meta,
+        target_meta_file,
+        time_index,
+        shape,
+        flist_chunks,
+        out_file,
+        threshold=1e-4,
+        max_workers=None,
+    ):
+        """Collect chunks for single feature
+
+        Parameters
+        ----------
+        dset : str
+            Dataset name to collect.
+        target_masked_meta : pd.DataFrame
+            Same as subset_masked_meta but instead for the entire list of files
+            to be collected.
+        target_meta_file : str
+            Path to target final meta containing coordinates to keep from the
+            full file list collected meta. This can be but is not necessarily a
+            subset of the full list of coordinates for all files in the file
+            list. This is used to remove coordinates from the full file list
+            which are not present in the target_meta. Either this full
+            meta or a subset, depending on which coordinates are present in
+            the data to be collected, will be the final meta for the collected
+            output files.
+        time_index : pd.DatetimeIndex
+            Concatenated datetime index for the given file paths.
+        shape : tuple
+            Output (collected) dataset shape
+        flist_chunks : list
+            List of file list chunks. Used to split collection and writing into
+            multiple steps.
+        out_file : str
+            File path of final output file.
+        threshold : float
+            Threshold distance for finding target coordinates within full meta
+        max_workers : int | None
+            Number of workers to use in parallel. 1 runs serial,
+            None will use all available workers.
+        """
+        logger.debug('Collecting dataset "%s".', dset)
+
+        if len(flist_chunks) == 1:
+            self._collect_flist(
+                dset,
+                target_masked_meta,
+                time_index,
+                shape,
+                flist_chunks[0],
+                out_file,
+                target_masked_meta,
+                max_workers=max_workers,
+            )
+
+        else:
+            for i, flist in enumerate(flist_chunks):
+                logger.info(
+                    'Collecting file list chunk %s out of %s ',
+                    i + 1,
+                    len(flist_chunks),
+                )
+                out = self.get_collection_attrs(
+                    flist,
+                    max_workers=max_workers,
+                    target_meta_file=target_meta_file,
+                    threshold=threshold,
+                )
+                time_index, _, masked_meta, shape, _ = out
+                self._collect_flist(
+                    dset,
+                    masked_meta,
+                    time_index,
+                    shape,
+                    flist,
+                    out_file,
+                    target_masked_meta,
+                    max_workers=max_workers,
+                )
+
     @classmethod
     def collect(
         cls,
@@ -692,8 +757,8 @@
         Job name for status file if running from pipeline.
         pipeline_step : str, optional
             Name of the pipeline step being run. If ``None``, the
-            ``pipeline_step`` will be set to the ``"collect``,
-            mimicking old reV behavior. By default, ``None``.
+            ``pipeline_step`` will be set to ``"collect"``, mimicking old
+            reV behavior. By default, ``None``.
         target_meta_file : str
             Path to target final meta containing coordinates to keep from the
             full file list collected meta. This can be but is not necessarily a
@@ -734,13 +799,8 @@
                 logger.info('overwrite=True, removing %s', out_file)
                 os.remove(out_file)
 
-        extent_files = collector.get_unique_chunk_files(collector.flist)
-        logger.info(
-            'Using %s unique chunk files to build time index and meta.',
-            len(extent_files),
-        )
         out = collector.get_collection_attrs(
-            extent_files,
+            collector.flist,
             max_workers=max_workers,
             target_meta_file=target_meta_file,
             threshold=threshold,
         )
 
         time_index, target_meta, target_masked_meta = out[:3]
         shape, global_attrs = out[3:]
+        flist_chunks = collector.get_flist_chunks(
+            collector.flist, n_writes=n_writes
+        )
+        if not os.path.exists(out_file):
+            collector._init_h5(out_file, time_index, target_meta, global_attrs)
 
         for dset in features:
             logger.debug('Collecting dataset "%s".', dset)
-            flist_chunks = collector.get_flist_chunks(
-                collector.flist, n_writes=n_writes
+            collector.collect_feature(
+                dset=dset,
+                target_masked_meta=target_masked_meta,
+                target_meta_file=target_meta_file,
+                time_index=time_index,
+                shape=shape,
+                flist_chunks=flist_chunks,
+                out_file=out_file,
+                threshold=threshold,
+                max_workers=max_workers,
             )
-            if not os.path.exists(out_file):
-                collector._init_h5(
-                    out_file, time_index, target_meta, global_attrs
-                )
-
-            if len(flist_chunks) == 1:
-                collector._collect_flist(
-                    dset,
-                    target_masked_meta,
-                    time_index,
-                    shape,
-                    flist_chunks[0],
-                    out_file,
-                    target_masked_meta,
-                    max_workers=max_workers,
-                )
-
-            else:
-                for i, flist in enumerate(flist_chunks):
-                    logger.info(
-                        'Collecting file list chunk %s out of %s ',
-                        i + 1,
-                        len(flist_chunks),
-                    )
-                    out = collector.get_collection_attrs(
-                        flist,
-                        max_workers=max_workers,
-                        target_meta_file=target_meta_file,
-                        threshold=threshold,
-                    )
-                    time_index, target_meta, masked_meta, shape, _ = out
-                    collector._collect_flist(
-                        dset,
-                        masked_meta,
-                        time_index,
-                        shape,
-                        flist,
-                        out_file,
-                        target_masked_meta,
-                        max_workers=max_workers,
-                    )
-
         logger.info('Finished file collection.')
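After the refactor the flow in collect is: gather attributes once, split the
file list according to n_writes, initialize the output file, then call
collect_feature per dataset. A hedged usage sketch (the CollectorH5 name,
import path, argument order, and feature names are assumptions to verify
against the sup3r API):

    from sup3r.postprocessing.collectors.h5 import CollectorH5

    CollectorH5.collect(
        './chunks/sup3r_chunk_*_*.h5',  # unix-style pattern of chunk files
        out_file='sup3r_collected.h5',
        features=['windspeed_100m', 'winddirection_100m'],
        n_writes=2,       # split collection/writing into two passes
        max_workers=8,    # thread pool size for attribute reads
    )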
diff --git a/sup3r/postprocessing/writers/base.py b/sup3r/postprocessing/writers/base.py
index d7825cdb1..7d39d1e54 100644
--- a/sup3r/postprocessing/writers/base.py
+++ b/sup3r/postprocessing/writers/base.py
@@ -381,7 +381,7 @@ def enforce_limits(features, data):
             mins.append(min_val)
 
         data = np.maximum(data, mins)
-        return np.minimum(data, maxes)
+        return np.minimum(data, maxes).astype(np.float32)
 
     @staticmethod
     def pad_lat_lon(lat_lon):
diff --git a/sup3r/postprocessing/writers/h5.py b/sup3r/postprocessing/writers/h5.py
index 8a7c27a08..3a5dedd7f 100644
--- a/sup3r/postprocessing/writers/h5.py
+++ b/sup3r/postprocessing/writers/h5.py
@@ -149,7 +149,7 @@ def _transform_output(cls, data, features, lat_lon, max_workers=None):
             data, features, lat_lon, max_workers=max_workers
         )
         features = cls.get_renamed_features(features)
-        data = cls.enforce_limits(features, data)
+        data = cls.enforce_limits(features=features, data=data)
         return data, features
 
     @classmethod
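The trailing .astype(np.float32) matters because NumPy upcasts a float32 array
to float64 when it is clipped against Python-float limits, which would
otherwise carry double-size arrays through the write path. A standalone sketch
of the behavior (the limit values are made up):

    import numpy as np

    data = np.array([[150.0, -5.0], [60.0, 30.0]], dtype=np.float32)
    mins, maxes = [0.0], [120.0]  # per-feature physical limits

    clipped = np.minimum(np.maximum(data, mins), maxes)
    print(clipped.dtype)                     # float64: upcast by the float lists
    print(clipped.astype(np.float32).dtype)  # float32, as enforce_limits now returns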