Skip to content

Commit

Permalink
feat: support beancount 3
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhantgoel committed Jul 3, 2024
1 parent 5d5f287 commit 5845dcf
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 39 deletions.
42 changes: 21 additions & 21 deletions beancount_n26/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from datetime import datetime
from typing import Mapping, Tuple, Dict, List, Optional

from beancount.core import data
from beancount.core import data, flags
from beancount.core.amount import Amount
from beancount.core.number import Decimal
from beancount.ingest import importer
from beancount.core.position import CostSpec
from beangulp.importer import Importer

HEADER_FIELDS = {
"en": OrderedDict(
Expand Down Expand Up @@ -119,18 +119,18 @@ class InvalidFormatError(Exception):
PayeePattern = namedtuple("PayeePattern", ["regex", "account"])


class N26Importer(importer.ImporterProtocol):
class N26Importer(Importer):
def __init__(
self,
iban: str,
account: str,
account_name: str,
language: str = "en",
file_encoding: str = "utf-8",
account_patterns: Dict[str, List[str]] = {},
exchange_fees_account: Optional[str] = None,
):
self.iban = iban
self.account = account
self.account_name = account_name
self.language = language
self.file_encoding = file_encoding
self.payee_patterns = set()
Expand Down Expand Up @@ -161,6 +161,9 @@ def __init__(
)
)

def account(self) -> data.Account:
return data.Account(self.account_name)

def _translate(self, key):
return self._translation_strings[key]

Expand All @@ -170,16 +173,13 @@ def _parse_date(self, entry, key="date"):
def name(self):
return "N26 {}".format(self.__class__.__name__)

def file_account(self, _):
return self.account

def file_date(self, file_):
if not self.identify(file_):
def date(self, filepath: str) -> Optional[datetime.date]:
if not self.identify(filepath):
return None

date = None

with open(file_.name, encoding=self.file_encoding) as fd:
with open(filepath, encoding=self.file_encoding) as fd:
reader = csv.DictReader(
fd, delimiter=",", quoting=csv.QUOTE_MINIMAL, quotechar='"'
)
Expand Down Expand Up @@ -207,19 +207,19 @@ def is_valid_header(self, line: str) -> bool:

return True

def identify(self, file_) -> bool:
def identify(self, filepath: str) -> bool:
try:
with open(file_.name, encoding=self.file_encoding) as fd:
with open(filepath, encoding=self.file_encoding) as fd:
line = fd.readline().strip()
except ValueError:
return False
else:
return self.is_valid_header(line)

def extract(self, file_, existing_entries=None):
def extract(self, filepath: str, existing: data.Entries = None) -> data.Entries:
entries = []

if not self.identify(file_):
if not self.identify(filepath):
return []

s_amount_eur = self._translate("amount_eur")
Expand All @@ -229,13 +229,13 @@ def extract(self, file_, existing_entries=None):
s_type_foreign_currency = self._translate("type_foreign_currency")
s_exchange_rate = self._translate("exchange_rate")

with open(file_.name, encoding=self.file_encoding) as fd:
with open(filepath, encoding=self.file_encoding) as fd:
reader = csv.DictReader(
fd, delimiter=",", quoting=csv.QUOTE_MINIMAL, quotechar='"'
)

for index, line in enumerate(reader):
meta = data.new_metadata(file_.name, index)
meta = data.new_metadata(filepath, index)

postings = []

Expand All @@ -254,7 +254,7 @@ def extract(self, file_, existing_entries=None):

postings += [
data.Posting(
self.account,
self.account(),
Amount(-fees, "EUR"),
None,
None,
Expand All @@ -273,7 +273,7 @@ def extract(self, file_, existing_entries=None):

postings += [
data.Posting(
self.account,
self.account(),
Amount(amount_eur - fees, "EUR"),
CostSpec(exchange_rate, None, currency, None, None, None),
None,
Expand All @@ -286,7 +286,7 @@ def extract(self, file_, existing_entries=None):

postings += [
data.Posting(
self.account,
self.account(),
Amount(amount, "EUR"),
None,
None,
Expand Down Expand Up @@ -315,7 +315,7 @@ def extract(self, file_, existing_entries=None):
data.Transaction(
meta,
self._parse_date(line),
self.FLAG,
flags.FLAG_OKAY,
line[s_payee],
line[s_payment_reference],
data.EMPTY_SET,
Expand Down
29 changes: 11 additions & 18 deletions tests/test_n26_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def test_identify_with_optional(importer, filename):
)
)

with open(filename) as fd:
assert importer.identify(fd)
assert importer.identify(filename)


def test_identify_correct_no_optional(importer, filename):
Expand All @@ -87,8 +86,7 @@ def test_identify_correct_no_optional(importer, filename):
)
)

with open(filename) as fd:
assert importer.identify(fd)
assert importer.identify(filename)


def test_extract_no_transactions(importer, filename):
Expand All @@ -102,8 +100,7 @@ def test_extract_no_transactions(importer, filename):
)
)

with open(filename) as fd:
transactions = importer.extract(fd)
transactions = importer.extract(filename)

assert len(transactions) == 0

Expand Down Expand Up @@ -145,9 +142,8 @@ def test_extract_single_transaction(importer, filename):
)
)

with open(filename) as fd:
transactions = importer.extract(fd)
date = importer.file_date(fd)
transactions = importer.extract(filename)
date = importer.date(filename)

assert date == datetime.date(2019, 10, 10)

Expand Down Expand Up @@ -177,9 +173,8 @@ def test_extract_multiple_transactions(importer, filename):
)
)

with open(filename) as fd:
transactions = importer.extract(fd)
date = importer.file_date(fd)
transactions = importer.extract(filename)
date = importer.date(filename)

assert date == datetime.date(2020, 1, 5)
assert len(transactions) == 3
Expand Down Expand Up @@ -231,9 +226,8 @@ def test_extract_multiple_transactions_with_classification(
)
)

with open(filename) as fd:
transactions = importer_with_classification.extract(fd)
date = importer_with_classification.file_date(fd)
transactions = importer_with_classification.extract(filename)
date = importer_with_classification.date(filename)

assert date == datetime.date(2020, 1, 5)
assert len(transactions) == 3
Expand Down Expand Up @@ -302,9 +296,8 @@ def test_extract_conversion(importer, filename):
)
)

with open(filename) as fd:
transactions = importer.extract(fd)
date = importer.file_date(fd)
transactions = importer.extract(filename)
date = importer.date(filename)

assert date == datetime.date(2022, 8, 4)
assert len(transactions) == 4
Expand Down

0 comments on commit 5845dcf

Please sign in to comment.