diff --git a/dataclass_csv/dataclass_reader.py b/dataclass_csv/dataclass_reader.py index f36404e..2dc89c8 100644 --- a/dataclass_csv/dataclass_reader.py +++ b/dataclass_csv/dataclass_reader.py @@ -1,7 +1,7 @@ import dataclasses import csv -from datetime import datetime +from datetime import date, datetime from distutils.util import strtobool from typing import Union, Type, Optional, Sequence, Dict, Any, List @@ -157,7 +157,7 @@ def _get_value(self, row, field): else: return value - def _parse_date_value(self, field, date_value): + def _parse_date_value(self, field, date_value, field_type): dateformat = self._get_metadata_option(field, "dateformat") if not isinstance(date_value, str): @@ -175,7 +175,13 @@ def _parse_date_value(self, field, date_value): "{'dateformat': })`." ) ) - return datetime.strptime(date_value, dateformat) + + datetime_obj = datetime.strptime(date_value, dateformat) + + if field_type == date: + return datetime_obj.date() + else: + return datetime_obj def _process_row(self, row): values = [] @@ -200,9 +206,9 @@ def _process_row(self, row): if len(type_args) == 1: field_type = type_args[0] - if field_type is datetime: + if field_type is datetime or field_type is date: try: - transformed_value = self._parse_date_value(field, value) + transformed_value = self._parse_date_value(field, value, field_type) except ValueError as ex: raise CsvValueError(ex, line_number=self._reader.line_num) from None else: diff --git a/tests/mocks.py b/tests/mocks.py index 3ea242f..682c08c 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,7 +1,7 @@ import dataclasses import re -from datetime import datetime +from datetime import date, datetime from dataclass_csv import dateformat, accept_whitespaces @@ -52,6 +52,13 @@ class UserWithDateFormatDecorator: create_date: datetime +@dateformat("%Y-%m-%d") +@dataclasses.dataclass +class UserWithDateFormatDecoratorAndDateField: + name: str + create_date: date + + @dataclasses.dataclass class UserWithDateFormatMetadata: name: str diff --git a/tests/test_dataclass_reader.py b/tests/test_dataclass_reader.py index 1c53c80..5a08d3f 100644 --- a/tests/test_dataclass_reader.py +++ b/tests/test_dataclass_reader.py @@ -1,7 +1,7 @@ import pytest import dataclasses -from datetime import datetime +from datetime import date, datetime from dataclass_csv import DataclassReader, CsvValueError from .mocks import ( @@ -12,6 +12,7 @@ UserWithInitFalse, UserWithInitFalseAndDefaultValue, UserWithDefaultDatetimeField, + UserWithDateFormatDecoratorAndDateField, UserWithSSN, SSN, UserWithEmail, @@ -182,6 +183,17 @@ def test_reader_with_datetime_default_value(create_csv): assert isinstance(items[0].birthday, datetime) +def test_reader_with_date(create_csv): + csv_file = create_csv({"name": "User", "create_date": "2019-01-01"}) + + with csv_file.open() as f: + reader = DataclassReader(f, UserWithDateFormatDecoratorAndDateField) + items = list(reader) + assert len(items) > 0 + assert isinstance(items[0].create_date, date) + assert items[0].create_date == date(2019, 1, 1) + + def test_should_parse_user_defined_types(create_csv): csv_file = create_csv( [