diff --git a/orsopy/fileio/base.py b/orsopy/fileio/base.py index 8e3db140..3514e2e7 100644 --- a/orsopy/fileio/base.py +++ b/orsopy/fileio/base.py @@ -431,6 +431,9 @@ def represent_data(self, data): elif isinstance(data, datetime.datetime): value = data.isoformat("T") return super().represent_scalar("tag:yaml.org,2002:timestamp", value) + elif np.isscalar(data) and hasattr(data, "item"): + # If data is a numpy scalar, convert to a python object + return super().represent_data(data.item()) else: return super().represent_data(data) diff --git a/tests/test_fileio/test_orso.py b/tests/test_fileio/test_orso.py index 5199c998..417e98d8 100644 --- a/tests/test_fileio/test_orso.py +++ b/tests/test_fileio/test_orso.py @@ -230,6 +230,17 @@ def test_extra_elements(self): info = datasets[0].info assert hasattr(info.data_source.measurement.instrument_settings.incident_angle, "resolution") + def test_save_numpy_scalar_dtypes(self): + info = fileio.Orso.empty() + info.data_source.measurement.instrument_settings.wavelength = Value(np.float64(10.0)) + info.data_source.measurement.instrument_settings.incident_angle = Value(np.int32(2)) + ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(10, 2)) + fileio.save_orso([ds], "test_numpy.ort") + ls = fileio.load_orso("test_numpy.ort") + i_s = ls[0].info.data_source.measurement.instrument_settings + assert i_s.wavelength.magnitude == 10.0 + assert i_s.incident_angle.magnitude == 2 + class TestFunctions(unittest.TestCase): """