diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 3ffc7098b4..4b48fa89f7 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -211,12 +211,14 @@ def s_enhance(self): model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" models = getattr(self, 'models', [self]) - s_enhances = [m.meta['s_enhance'] for m in models] + s_enhances = [m.meta.get('s_enhance', None) for m in models] s_enhance = ( self.get_s_enhance_from_layers() if any(s is None for s in s_enhances) - else np.prod(s_enhances) + else int(np.prod(s_enhances)) ) + if len(models) == 1: + self.meta['s_enhance'] = s_enhance return s_enhance @property @@ -225,12 +227,14 @@ def t_enhance(self): model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data""" models = getattr(self, 'models', [self]) - t_enhances = [m.meta['t_enhance'] for m in models] + t_enhances = [m.meta.get('t_enhance', None) for m in models] t_enhance = ( self.get_t_enhance_from_layers() if any(t is None for t in t_enhances) - else np.prod(t_enhances) + else int(np.prod(t_enhances)) ) + if len(models) == 1: + self.meta['t_enhance'] = t_enhance return t_enhance @property @@ -593,7 +597,7 @@ def save_params(self, out_dir): fp_params, 'w', encoding=locale.getpreferredencoding(False) ) as f: params = self.model_params - json.dump(params, f, sort_keys=True, indent=2) + json.dump(params, f, sort_keys=True, indent=2, default=safe_cast) # pylint: disable=E1101,W0201,E0203 diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 6d8bf4c512..3d7bf0a873 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -38,9 +38,9 @@ def safe_cast(o): return str(o) -def safe_serialize(obj): +def safe_serialize(obj, **kwargs): """json.dumps with non-serializable object handling.""" - return json.dumps(obj, default=safe_cast) + return json.dumps(obj, default=safe_cast, **kwargs) class Timer: