Skip to content

Commit

Permalink
Merge pull request #101 from nvaytet/numpy-dtypes
Browse files Browse the repository at this point in the history
Convert numpy scalars to native python objects before sending to yaml dump
  • Loading branch information
andyfaff authored Apr 24, 2023
2 parents 1f772dd + 1c2d007 commit 4623a87
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
3 changes: 3 additions & 0 deletions orsopy/fileio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,9 @@ def represent_data(self, data):
elif isinstance(data, datetime.datetime):
value = data.isoformat("T")
return super().represent_scalar("tag:yaml.org,2002:timestamp", value)
elif np.isscalar(data) and hasattr(data, "item"):
# If data is a numpy scalar, convert to a python object
return super().represent_data(data.item())
else:
return super().represent_data(data)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_fileio/test_orso.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,17 @@ def test_extra_elements(self):
info = datasets[0].info
assert hasattr(info.data_source.measurement.instrument_settings.incident_angle, "resolution")

def test_save_numpy_scalar_dtypes(self):
info = fileio.Orso.empty()
info.data_source.measurement.instrument_settings.wavelength = Value(np.float64(10.0))
info.data_source.measurement.instrument_settings.incident_angle = Value(np.int32(2))
ds = fileio.orso.OrsoDataset(info, np.arange(20.).reshape(10, 2))
fileio.save_orso([ds], "test_numpy.ort")
ls = fileio.load_orso("test_numpy.ort")
i_s = ls[0].info.data_source.measurement.instrument_settings
assert i_s.wavelength.magnitude == 10.0
assert i_s.incident_angle.magnitude == 2


class TestFunctions(unittest.TestCase):
"""
Expand Down

0 comments on commit 4623a87

Please sign in to comment.