Skip to content

Commit

Permalink
Dataframe includes non-dimension columns. (#50)
Browse files Browse the repository at this point in the history
This fixes an inconsistency between the actual columns in the dataframe
and the df meta.
  • Loading branch information
alxmrs authored Mar 24, 2024
1 parent 36b0317 commit 11dc7bf
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions xarray_sql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Row = t.List[t.Any]


# deprecated
def get_columns(ds: xr.Dataset) -> t.List[str]:
    """Return the dataset's dimension names followed by its data variable names.

    Deprecated: only dimensions and data variables are listed, so any other
    column present in a produced DataFrame (e.g. a non-dimension coordinate)
    is not reported here.

    Args:
        ds: The dataset whose column names are wanted.

    Returns:
        Dimension names first, then data variable names, each in the order
        the dataset reports them.
    """
    # Iterate the containers directly instead of calling `.keys()`: this works
    # whether `ds.dims` is a mapping (name -> length) or a plain collection of
    # names, which is what newer xarray releases are moving to.
    return list(ds.dims) + list(ds.data_vars)

Expand Down
2 changes: 1 addition & 1 deletion xarray_sql/df.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def pivot(b: Block) -> pd.DataFrame:
f'{"_".join(list(ds.data_vars.keys()))}'
)

columns = core.get_columns(ds)
columns = pivot(blocks[0]).columns

# TODO(#18): Is it possible to pass the length (known now) here?
meta = {c: ds[c].dtype for c in columns}
Expand Down
37 changes: 37 additions & 0 deletions xarray_sql/df_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,37 @@

import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr

from .df import explode, read_xarray, block_slices


def rand_wx(start: str, end: str) -> xr.Dataset:
    """Build a reproducible random weather dataset for tests.

    The dataset has `temperature` and `precipitation` variables over the
    dimensions (lat, lon, time, level), plus a scalar `reference_time`
    coordinate that is not a dimension.

    Args:
        start: First timestamp of the hourly time axis; also used as the
            `reference_time` coordinate.
        end: Last timestamp of the hourly time axis (inclusive).

    Returns:
        An `xr.Dataset` on a 720 x 1440 lat/lon grid with two levels.
    """
    # Fixed seed so every call produces the same values.
    np.random.seed(42)
    lat = np.linspace(-90, 90, num=720)
    lon = np.linspace(-180, 180, num=1440)
    # Lower-case 'h' is the hourly alias; upper-case 'H' is deprecated in
    # recent pandas releases and emits a FutureWarning.
    time = pd.date_range(start, end, freq='h')
    level = np.array([1000, 500], dtype=np.int32)
    reference_time = pd.Timestamp(start)

    # Derive the array shape from the coordinates so the two cannot drift.
    dims = ['lat', 'lon', 'time', 'level']
    shape = (len(lat), len(lon), len(time), len(level))
    temperature = 15 + 8 * np.random.randn(*shape)
    precipitation = 10 * np.random.rand(*shape)

    return xr.Dataset(
        data_vars=dict(
            temperature=(dims, temperature),
            precipitation=(dims, precipitation),
        ),
        coords=dict(
            lat=lat,
            lon=lon,
            time=time,
            level=level,
            reference_time=reference_time,
        ),
        attrs=dict(description='Random weather.'),
    )


class DaskTestCase(unittest.TestCase):

def setUp(self) -> None:
Expand All @@ -18,6 +44,7 @@ def setUp(self) -> None:
self.air_small = self.air.isel(
time=slice(0, 12), lat=slice(0, 11), lon=slice(0, 10)
).chunk(self.chunks)
self.randwx = rand_wx('1995-01-13T00', '1995-01-13T01')


class ExplodeTest(DaskTestCase):
Expand Down Expand Up @@ -84,6 +111,16 @@ def test_chunk_perf(self):
self.assertIsNotNone(df)
self.assertEqual(len(df), np.prod(list(self.air.dims.values())))

def test_column_metadata_preserved(self):
    """Computing the dataframe must not hit a column/metadata mismatch.

    `self.randwx` carries a non-dimension coordinate (`reference_time`), so
    this guards against the declared metadata omitting such columns.
    """
    try:
        read_xarray(self.randwx, chunks=dict(time=24)).compute()
    except ValueError as e:
        if (
            'The columns in the computed data do not match the columns in the'
            ' provided metadata' in str(e)
        ):
            self.fail('Column metadata is incorrect.')
        # Any other ValueError is unexpected; re-raise it so the test does
        # not pass silently on an unrelated failure.
        raise


if __name__ == '__main__':
unittest.main()

0 comments on commit 11dc7bf

Please sign in to comment.