
Commit

Don't need a separate status file write; this is handled in get_node_cmd
bnb32 committed Aug 30, 2024
1 parent a7c4d1d · commit 0d198b6
Showing 2 changed files with 17 additions and 46 deletions.
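
Both collectors drop the same pattern: timing the run with time.time(), building a status dict, and writing it with gaps' Status.make_single_job_file. The write_status, job_name, and pipeline_step keyword arguments go with it, since the node command built by get_node_cmd already records job status (the get_node_cmd side is not part of this diff). A condensed, self-contained view of the removed pattern follows; the helper name _write_collect_status is used only for illustration and does not exist in the codebase.

    import os
    import time

    from gaps import Status


    def _write_collect_status(out_file, flist, job_name, t0,
                              pipeline_step='collect'):
        """Illustrative stand-in for the block deleted from both collectors:
        each collector used to write its own gaps status file on success."""
        status = {
            'out_dir': os.path.dirname(out_file),
            'fout': out_file,
            'flist': flist,
            'job_status': 'successful',
            'runtime': (time.time() - t0) / 60,  # minutes
        }
        Status.make_single_job_file(
            os.path.dirname(out_file), pipeline_step, job_name, status
        )

With this commit applied, collect() no longer needs the t0 timestamp or those three keyword arguments.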
46 changes: 17 additions & 29 deletions sup3r/postprocessing/collectors/h5.py
@@ -2,14 +2,12 @@

import logging
import os
import time
from glob import glob
from warnings import warn

import dask
import numpy as np
import pandas as pd
from gaps import Status
from rex.utilities.loggers import init_logger
from scipy.spatial import KDTree

@@ -210,7 +208,9 @@ def _get_file_attrs(self, file):
time_index = f.time_index
if file not in self.file_attrs:
self.file_attrs[file] = {'meta': meta, 'time_index': time_index}
logger.debug('Finished getting info for file: %s', file)
logger.debug(
'Finished getting info for file: %s. %s', file, _mem_check()
)
return meta, time_index

def get_unique_chunk_files(self, file_paths):
@@ -228,12 +228,12 @@ def get_unique_chunk_files(self, file_paths):
"""
t_chunk, s_chunk = self.get_chunk_indices(file_paths[0])
t_files = file_paths[0].replace(f'{t_chunk}_{s_chunk}', f'*_{s_chunk}')
t_files = glob(t_files)
t_files = set(glob(t_files)).intersection(file_paths)
logger.info('Found %s unique temporal chunks', len(t_files))
s_files = file_paths[0].replace(f'{t_chunk}_{s_chunk}', f'{t_chunk}_*')
s_files = glob(s_files)
s_files = set(glob(s_files)).intersection(file_paths)
logger.info('Found %s unique spatial chunks', len(s_files))
return s_files + t_files
return list(s_files) + list(t_files)

def _get_collection_attrs(
self, file_paths, sort=True, sort_key=None, max_workers=None
@@ -276,7 +276,7 @@ def _get_collection_attrs(

logger.info(
'Getting collection attrs for full dataset with '
f'max_workers={max_workers}.'
'max_workers=%s. %s', max_workers, _mem_check()
)

time_index = [None] * len(file_paths)
@@ -289,8 +289,14 @@
out = dask.compute(
*tasks, scheduler='threads', num_workers=max_workers
)
logger.info(
'Finished getting meta and time_index for all unique chunks.'
)
for i, vals in enumerate(out):
meta[i], time_index[i] = vals
logger.debug(
'Finished filling arrays for file %s. %s', i, _mem_check()
)
time_index = pd.DatetimeIndex(np.concatenate(time_index))
time_index = time_index.sort_values()
time_index = time_index.drop_duplicates()
@@ -300,6 +306,7 @@
meta = meta.drop_duplicates(subset=['latitude', 'longitude'])
meta = meta.sort_values('gid')

logger.info('Finished building full meta and time index.')
return time_index, meta

def get_target_and_masked_meta(
@@ -403,21 +410,20 @@ def get_collection_attrs(
"""
logger.info(f'Using target_meta_file={target_meta_file}')
if isinstance(target_meta_file, str):
msg = (
f'Provided target meta ({target_meta_file}) does not ' 'exist.'
)
msg = f'Provided target meta ({target_meta_file}) does not exist.'
assert os.path.exists(target_meta_file), msg

time_index, meta = self._get_collection_attrs(
file_paths, sort=sort, sort_key=sort_key, max_workers=max_workers
)

logger.info('Getting target and masked meta.')
target_meta, masked_meta = self.get_target_and_masked_meta(
meta, target_meta_file, threshold=threshold
)

shape = (len(time_index), len(target_meta))

logger.info('Getting global attrs from %s', file_paths[0])
with RexOutputs(file_paths[0], mode='r') as fin:
global_attrs = fin.global_attrs

@@ -652,9 +658,6 @@ def collect(
max_workers=None,
log_level=None,
log_file=None,
write_status=False,
job_name=None,
pipeline_step=None,
target_meta_file=None,
n_writes=None,
overwrite=True,
@@ -709,8 +712,6 @@ def collect(
threshold : float
Threshold distance for finding target coordinates within full meta
"""
t0 = time.time()

logger.info(
'Initializing collection for file_paths=%s with max_workers=%s',
file_paths,
@@ -795,17 +796,4 @@ def collect(
max_workers=max_workers,
)

if write_status and job_name is not None:
status = {
'out_dir': os.path.dirname(out_file),
'fout': out_file,
'flist': collector.flist,
'job_status': 'successful',
'runtime': (time.time() - t0) / 60,
}
pipeline_step = pipeline_step or 'collect'
Status.make_single_job_file(
os.path.dirname(out_file), pipeline_step, job_name, status
)

logger.info('Finished file collection.')
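
Two smaller changes in h5.py sit alongside the status-write removal: get_unique_chunk_files now intersects the globbed temporal/spatial chunk files with the requested file_paths (so stray files matching the glob pattern but not passed in are ignored), and several log messages report memory usage via a _mem_check() helper. That helper is defined elsewhere in sup3r and is not shown in this diff; a minimal sketch of what such a helper could look like, assuming psutil:

    import psutil


    def _mem_check():
        # Hypothetical sketch only: return a short string describing system
        # memory usage so it can be appended to the collector's log messages.
        mem = psutil.virtual_memory()
        return (f'Current memory usage is {mem.used / 1e9:.3f} GB out of '
                f'{mem.total / 1e9:.3f} GB total.')
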
17 changes: 0 additions & 17 deletions sup3r/postprocessing/collectors/nc.py
@@ -5,9 +5,7 @@

import logging
import os
import time

from gaps import Status
from rex.utilities.loggers import init_logger

from sup3r.preprocessing.cachers import Cacher
@@ -29,8 +27,6 @@ def collect(
features='all',
log_level=None,
log_file=None,
write_status=False,
job_name=None,
overwrite=True,
res_kwargs=None,
):
@@ -62,8 +58,6 @@ def collect(
res_kwargs : dict | None
Dictionary of kwargs to pass to xarray.open_mfdataset.
"""
t0 = time.time()

logger.info(f'Initializing collection for file_paths={file_paths}')

if log_level is not None:
@@ -88,17 +82,6 @@ def collect(
out = xr_open_mfdataset(collector.flist, **res_kwargs)
Cacher.write_netcdf(tmp_file, data=out, features=features)

if write_status and job_name is not None:
status = {
'out_dir': os.path.dirname(out_file),
'fout': out_file,
'flist': collector.flist,
'job_status': 'successful',
'runtime': (time.time() - t0) / 60,
}
Status.make_single_job_file(
os.path.dirname(out_file), 'collect', job_name, status
)
os.replace(tmp_file, out_file)
logger.info('Moved %s to %s.', tmp_file, out_file)
