Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Aug 22, 2024
1 parent 9058109 commit bafcb8d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
14 changes: 9 additions & 5 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,14 @@ def s_enhance(self):
model training during high res coarsening and also in forward pass
routine to determine shape of needed exogenous data"""
models = getattr(self, 'models', [self])
s_enhances = [m.meta['s_enhance'] for m in models]
s_enhances = [m.meta.get('s_enhance', None) for m in models]
s_enhance = (
self.get_s_enhance_from_layers()
if any(s is None for s in s_enhances)
else np.prod(s_enhances)
else int(np.prod(s_enhances))
)
if len(models) == 1:
self.meta['s_enhance'] = s_enhance
return s_enhance

@property
Expand All @@ -225,12 +227,14 @@ def t_enhance(self):
model training during high res coarsening and also in forward pass
routine to determine shape of needed exogenous data"""
models = getattr(self, 'models', [self])
t_enhances = [m.meta['t_enhance'] for m in models]
t_enhances = [m.meta.get('t_enhance', None) for m in models]
t_enhance = (
self.get_t_enhance_from_layers()
if any(t is None for t in t_enhances)
else np.prod(t_enhances)
else int(np.prod(t_enhances))
)
if len(models) == 1:
self.meta['t_enhance'] = t_enhance
return t_enhance

@property
Expand Down Expand Up @@ -593,7 +597,7 @@ def save_params(self, out_dir):
fp_params, 'w', encoding=locale.getpreferredencoding(False)
) as f:
params = self.model_params
json.dump(params, f, sort_keys=True, indent=2)
json.dump(params, f, sort_keys=True, indent=2, default=safe_cast)


# pylint: disable=E1101,W0201,E0203
Expand Down
4 changes: 2 additions & 2 deletions sup3r/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def safe_cast(o):
return str(o)


def safe_serialize(obj):
def safe_serialize(obj, **kwargs):
"""json.dumps with non-serializable object handling."""
return json.dumps(obj, default=safe_cast)
return json.dumps(obj, default=safe_cast, **kwargs)


class Timer:
Expand Down

0 comments on commit bafcb8d

Please sign in to comment.