Skip to content

Commit

Permalink
Split up attribute collection for meta and time index. Significantly …
Browse files Browse the repository at this point in the history
…reduced memory use, since the number of spatial chunks tends to be much lower than the number of time chunks.
  • Loading branch information
bnb32 committed Aug 31, 2024
1 parent 0d198b6 commit ce76db8
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 119 deletions.
23 changes: 18 additions & 5 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ def __post_init__(self):
temporal_pad=self.temporal_pad,
)
self.n_chunks = self.fwp_slicer.n_chunks
self.out_files = self.get_out_files(out_files=self.out_pattern)
self.node_chunks = self._get_node_chunks()

if not self.head_node:
self.out_files = self.get_out_files(out_files=self.out_pattern)
self.hr_lat_lon = self.get_hr_lat_lon()
hr_shape = self.hr_lat_lon.shape[:-1]
self.gids = np.arange(np.prod(hr_shape)).reshape(hr_shape)
Expand Down Expand Up @@ -280,8 +280,21 @@ def _get_node_chunks(self):
"""Get array of lists such that node_chunks[i] is a list of
indices for the ith node indexing the chunks that will be sent through
the generator on the ith node."""
node_chunks = min(self.max_nodes or np.inf, self.n_chunks)
return np.array_split(np.arange(self.n_chunks), node_chunks)
logger.info('Checking for unfinished chunks.')
unfinished_chunks = [
n
for n in range(self.n_chunks)
if not self.chunk_finished(chunk_idx=n, log=False)
]
logger.info(
'%s of %s chunks are unfinished.',
len(unfinished_chunks),
self.n_chunks,
)
node_chunks = min(
self.max_nodes or np.inf, max((1, len(unfinished_chunks)))
)
return np.array_split(unfinished_chunks, node_chunks)

def _get_fwp_chunk_shape(self):
"""Get fwp_chunk_shape with default shape equal to the input handler
Expand Down Expand Up @@ -550,14 +563,14 @@ def node_finished(self, node_idx):
"""Check if all out files for a given node have been saved"""
return all(self.chunk_finished(i) for i in self.node_chunks[node_idx])

def chunk_finished(self, chunk_idx):
def chunk_finished(self, chunk_idx, log=True):
"""Check if process for given chunk_index has already been run.
Considered finished if there is already an output file and incremental
is False."""

out_file = self.out_files[chunk_idx]
check = os.path.exists(out_file) and self.incremental
if check:
if check and log:
logger.info(
'%s already exists and incremental = True. Skipping forward '
'pass for chunk index %s.',
Expand Down
Loading

0 comments on commit ce76db8

Please sign in to comment.