Commit
compute node_chunks without masked chunks
bnb32 committed Sep 25, 2024
1 parent c4b8089 commit a1a6038
Showing 5 changed files with 22 additions and 29 deletions.
1 change: 1 addition & 0 deletions sup3r/cli.py
@@ -156,6 +156,7 @@ def solar(ctx, verbose):
     {
         "fp_pattern": "./chunks/sup3r*.h5",
         "nsrdb_fp": "/datasets/NSRDB/current/nsrdb_2015.h5",
+        "max_nodes": 100,
         "execution_control": {
             "option": "kestrel",
             "walltime": 4,
15 changes: 10 additions & 5 deletions sup3r/models/base.py
@@ -10,11 +10,12 @@
 import numpy as np
 import pandas as pd
 import tensorflow as tf
-from tensorflow.keras import optimizers
 
+from sup3r.preprocessing.utilities import get_class_kwargs
 from sup3r.utilities import VERSION_RECORD
 
 from .abstract import AbstractInterface, AbstractSingleModel
+from .utilities import get_optimizer_class
 
 logger = logging.getLogger(__name__)
 
@@ -332,14 +333,18 @@ def update_optimizer(self, option='generator', **kwargs):
         if 'gen' in option.lower() or 'all' in option.lower():
             conf = self.get_optimizer_config(self.optimizer)
             conf.update(**kwargs)
-            optimizer_class = getattr(optimizers, conf['name'])
-            self._optimizer = optimizer_class.from_config(conf)
+            optimizer_class = get_optimizer_class(conf)
+            self._optimizer = optimizer_class.from_config(
+                get_class_kwargs(optimizer_class, conf)
+            )
 
         if 'disc' in option.lower() or 'all' in option.lower():
             conf = self.get_optimizer_config(self.optimizer_disc)
             conf.update(**kwargs)
-            optimizer_class = getattr(optimizers, conf['name'])
-            self._optimizer_disc = optimizer_class.from_config(conf)
+            optimizer_class = get_optimizer_class(conf)
+            self._optimizer_disc = optimizer_class.from_config(
+                get_class_kwargs(optimizer_class, conf)
+            )
 
     @property
     def meta(self):
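The new imports move the optimizer lookup and keyword filtering out of the model class. Below is a minimal sketch of what helpers like get_optimizer_class and get_class_kwargs could look like; the actual implementations live in sup3r.models.utilities and sup3r.preprocessing.utilities and may differ, so treat these bodies as assumptions.

    # Hypothetical sketch of the helpers used above; the real sup3r
    # implementations may differ.
    import inspect

    from tensorflow.keras import optimizers


    def get_optimizer_class(conf):
        """Look up a Keras optimizer class from a config dict by its 'name' key."""
        return getattr(optimizers, conf['name'])


    def get_class_kwargs(cls, kwargs):
        """Keep only the keys that the class __init__ actually accepts."""
        params = inspect.signature(cls.__init__).parameters
        return {k: v for k, v in kwargs.items() if k in params}

With helpers like these, a call such as update_optimizer(option='all', learning_rate=1e-5) should rebuild both optimizers without passing unsupported keys into from_config.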
4 changes: 2 additions & 2 deletions sup3r/pipeline/forward_pass.py
@@ -460,7 +460,7 @@ def _run_serial(cls, strategy, node_index):
         fwp = cls(strategy, node_index=node_index)
         for i, chunk_index in enumerate(strategy.node_chunks[node_index]):
             now = dt.now()
-            if not strategy.chunk_skippable(chunk_index):
+            if not strategy.chunk_finished(chunk_index):
                 chunk = fwp.get_input_chunk(chunk_index=chunk_index)
                 failed, _ = cls.run_chunk(
                     chunk=chunk,
@@ -516,7 +516,7 @@ def _run_parallel(cls, strategy, node_index):
         with SpawnProcessPool(**pool_kws) as exe:
             now = dt.now()
             for _, chunk_index in enumerate(strategy.node_chunks[node_index]):
-                if not strategy.chunk_skippable(chunk_index):
+                if not strategy.chunk_finished(chunk_index):
                     chunk = fwp.get_input_chunk(chunk_index=chunk_index)
                     fut = exe.submit(
                         fwp.run_chunk,
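With masked chunks filtered out when node_chunks is built (see the strategy.py change below), the workers only need to skip chunks whose output has already been written, so chunk_skippable collapses to chunk_finished. A rough sketch of that kind of check, assuming one output file per chunk (the real strategy.chunk_finished may log and use different bookkeeping):

    import os


    def chunk_finished(out_files, chunk_idx):
        """Sketch: treat a chunk as finished if its output file already exists."""
        out_file = out_files[chunk_idx]
        return out_file is not None and os.path.exists(out_file)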
29 changes: 8 additions & 21 deletions sup3r/pipeline/strategy.py
@@ -284,28 +284,21 @@ def node_chunks(self):
"""Get array of lists such that node_chunks[i] is a list of
indices for the chunks that will be sent through the generator on the
ith node."""
node_chunks = min(self.max_nodes or np.inf, self.n_chunks)
return np.array_split(np.arange(self.n_chunks), node_chunks)
node_chunks = min(
self.max_nodes or np.inf, len(self.unmasked_chunks)
)
return np.array_split(np.arange(self.unmasked_chunks), node_chunks)

@property
def unfinished_chunks(self):
"""List of chunk indices that have not yet been written and are not
masked."""
def unmasked_chunks(self):
"""List of chunk indices that are not masked from the input spatial
region."""
return [
idx
for idx in np.arange(self.n_chunks)
if not self.chunk_skippable(idx, log=False)
if not self.chunk_masked(idx, log=False)
]

@property
def unfinished_node_chunks(self):
"""Get node_chunks lists which only include indices for chunks which
have not yet been written or are not masked."""
node_chunks = min(
self.max_nodes or np.inf, len(self.unfinished_chunks)
)
return np.array_split(self.unfinished_chunks, node_chunks)

def _get_fwp_chunk_shape(self):
"""Get fwp_chunk_shape with default shape equal to the input handler
shape"""
@@ -627,9 +620,3 @@ def chunk_masked(self, chunk_idx, log=True):
             s_chunk_idx,
         )
         return mask_check
-
-    def chunk_skippable(self, chunk_idx, log=True):
-        """Check if chunk is already written or masked."""
-        return self.chunk_masked(
-            chunk_idx=chunk_idx, log=log
-        ) or self.chunk_finished(chunk_idx=chunk_idx, log=log)
2 changes: 1 addition & 1 deletion sup3r/utilities/utilities.py
@@ -63,7 +63,7 @@ def xr_open_mfdataset(files, **kwargs):
         msg = 'Could not use xr.open_mfdataset to open %s. %s'
         if len(files) == 1:
             raise RuntimeError(msg % (files, e)) from e
-        msg += 'Trying to open them separately and merge. %s'
+        msg += 'Trying to open them separately and merge.'
         logger.warning(msg, files, e)
         warn(msg % (files, e))
         return merge_datasets(files, **default_kwargs)
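For the utilities.py change, the removed '%s' would have left three placeholders in the combined message while only two arguments (files, e) are supplied, so both the lazy logger.warning call and the eager msg % (files, e) interpolation would fail to format. A standalone illustration with made-up values, not sup3r code:

    # Minimal illustration of the placeholder mismatch fixed above.
    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    files, e = ['a.nc', 'b.nc'], ValueError('bad time index')
    msg = 'Could not use xr.open_mfdataset to open %s. %s'

    bad = msg + 'Trying to open them separately and merge. %s'
    # bad % (files, e) -> TypeError: not enough arguments for format string

    good = msg + 'Trying to open them separately and merge.'
    logger.warning(good, files, e)   # lazy %-formatting: two placeholders, two args
    print(good % (files, e))         # eager formatting works the same way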
