Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add date support #45

Merged
merged 1 commit into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions dataclass_csv/dataclass_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import csv

from datetime import datetime
from datetime import date, datetime
from distutils.util import strtobool
from typing import Union, Type, Optional, Sequence, Dict, Any, List

Expand Down Expand Up @@ -157,7 +157,7 @@ def _get_value(self, row, field):
else:
return value

def _parse_date_value(self, field, date_value):
def _parse_date_value(self, field, date_value, field_type):
dateformat = self._get_metadata_option(field, "dateformat")

if not isinstance(date_value, str):
Expand All @@ -175,7 +175,13 @@ def _parse_date_value(self, field, date_value):
"{'dateformat': <date_format>})`."
)
)
return datetime.strptime(date_value, dateformat)

datetime_obj = datetime.strptime(date_value, dateformat)

if field_type == date:
return datetime_obj.date()
else:
return datetime_obj

def _process_row(self, row):
values = []
Expand All @@ -200,9 +206,9 @@ def _process_row(self, row):
if len(type_args) == 1:
field_type = type_args[0]

if field_type is datetime:
if field_type is datetime or field_type is date:
try:
transformed_value = self._parse_date_value(field, value)
transformed_value = self._parse_date_value(field, value, field_type)
except ValueError as ex:
raise CsvValueError(ex, line_number=self._reader.line_num) from None
else:
Expand Down
9 changes: 8 additions & 1 deletion tests/mocks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import re

from datetime import datetime
from datetime import date, datetime

from dataclass_csv import dateformat, accept_whitespaces

Expand Down Expand Up @@ -52,6 +52,13 @@ class UserWithDateFormatDecorator:
create_date: datetime


@dateformat("%Y-%m-%d")
@dataclasses.dataclass
class UserWithDateFormatDecoratorAndDateField:
name: str
create_date: date


@dataclasses.dataclass
class UserWithDateFormatMetadata:
name: str
Expand Down
14 changes: 13 additions & 1 deletion tests/test_dataclass_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import dataclasses

from datetime import datetime
from datetime import date, datetime
from dataclass_csv import DataclassReader, CsvValueError

from .mocks import (
Expand All @@ -12,6 +12,7 @@
UserWithInitFalse,
UserWithInitFalseAndDefaultValue,
UserWithDefaultDatetimeField,
UserWithDateFormatDecoratorAndDateField,
UserWithSSN,
SSN,
UserWithEmail,
Expand Down Expand Up @@ -182,6 +183,17 @@ def test_reader_with_datetime_default_value(create_csv):
assert isinstance(items[0].birthday, datetime)


def test_reader_with_date(create_csv):
csv_file = create_csv({"name": "User", "create_date": "2019-01-01"})

with csv_file.open() as f:
reader = DataclassReader(f, UserWithDateFormatDecoratorAndDateField)
items = list(reader)
assert len(items) > 0
assert isinstance(items[0].create_date, date)
assert items[0].create_date == date(2019, 1, 1)


def test_should_parse_user_defined_types(create_csv):
csv_file = create_csv(
[
Expand Down