diff --git a/src/synthesized_datasets/_datasets.py b/src/synthesized_datasets/_datasets.py index 3057bb4..0a4f06c 100644 --- a/src/synthesized_datasets/_datasets.py +++ b/src/synthesized_datasets/_datasets.py @@ -74,22 +74,19 @@ def load(self) -> _pd.DataFrame: else: # CSV load is the default dtypes = { - col: ( - _PD_DTYPE_MAP[dtype] - if dtype - not in [ - _DType.DATETIME, - _DType.TIMEDELTA, - _DType.DATE, - _DType.TIME, - ] - else "string" - ) + col: _PD_DTYPE_MAP[dtype] for col, dtype in self._schema.items() + if dtype + not in [ + _DType.DATETIME, + _DType.TIMEDELTA, + _DType.DATE, + _DType.TIME, + ] } df = _pd.read_csv(self.url, dtype=dtypes) for col, dtype in self._schema.items(): - if dtype is [_DType.DATETIME, _DType.DATE]: + if dtype in [_DType.DATETIME, _DType.DATE]: df[col] = _pd.to_datetime(df[col], dayfirst=True) if dtype in [_DType.TIMEDELTA, _DType.TIME]: df[col] = _pd.to_timedelta(df[col])