Skip to content

Commit

Permalink
MAINT: Allow opening PdfReader as contextmanager
Browse files Browse the repository at this point in the history
This mirrors PdfWriter and also signals that the file pointer needs to be
managed, now that we sometimes keep files open.
  • Loading branch information
mjsir911 committed May 17, 2024
1 parent 51fbfa3 commit 44a828c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
21 changes: 21 additions & 0 deletions pypdf/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from datetime import datetime
from io import BytesIO, FileIO, UnsupportedOperation
from pathlib import Path
from types import TracebackType
from typing import (
Any,
Callable,
Expand All @@ -45,6 +46,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
Expand Down Expand Up @@ -278,6 +280,9 @@ class PdfReader:
password: Decrypt PDF file at initialization. If the
password is None, the file will not be decrypted.
Defaults to ``None``
Can also be used as a context manager, which automatically closes the
underlying file if the reader was given a file name or path rather than an open stream.
"""

@property
Expand Down Expand Up @@ -312,8 +317,10 @@ def __init__(
__name__,
)

self._opened_automatically = False
if isinstance(stream, (str, Path)):
stream = FileIO(stream, "rb")
self._opened_automatically = True
weakref.finalize(self, stream.close)

self.read(stream)
Expand Down Expand Up @@ -349,6 +356,20 @@ def close(self) -> None:
"""Close the underlying file handle"""
self.stream.close()

def __enter__(self) -> "PdfReader":
    """Enter the ``with`` block and hand back this reader unchanged."""
    return self

def __exit__(
    self,
    exc_type: Optional[Type[BaseException]],
    exc: Optional[BaseException],
    traceback: Optional[TracebackType],
) -> None:
    """
    Leave the ``with`` block.

    The stream is closed through ``close()`` only when
    ``_opened_automatically`` is set; otherwise it is left untouched.
    The exception arguments are not inspected and ``None`` is returned,
    so an exception raised inside the block is not suppressed.
    """
    if not self._opened_automatically:
        # The reader does not own the stream, so leave it alone.
        return
    self.close()

@property
def root_object(self) -> DictionaryObject:
"""Provide access to "/Root". standardized with PdfWriter."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,8 @@ def test_extract_text_hello_world():

def test_read_path():
# Reads a PDF given as a ``pathlib.Path`` and checks that it has one page.
# NOTE(review): this is a diff view without +/- markers. The diff header
# reports 2 additions and 2 deletions, so the plain assignment form below
# is presumably the removed code and the ``with`` form its replacement.
path = Path(RESOURCE_ROOT, "crazyones.pdf")
reader = PdfReader(path)
assert len(reader.pages) == 1
with PdfReader(path) as reader:
assert len(reader.pages) == 1


def test_read_not_binary_mode(caplog):
Expand Down

0 comments on commit 44a828c

Please sign in to comment.