From 44a828c0321f572fee1dc10fb3c24f9058f00153 Mon Sep 17 00:00:00 2001 From: Marco Sirabella Date: Wed, 20 Mar 2024 17:26:30 -0700 Subject: [PATCH] MAINT: Allow opening PdfReader as contextmanager To mirror PdfWriter, also hints towards file pointer management now that we keep files open sometimes. --- pypdf/_reader.py | 21 +++++++++++++++++++++ tests/test_reader.py | 4 ++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pypdf/_reader.py b/pypdf/_reader.py index 9bd25852cd..8a54858447 100644 --- a/pypdf/_reader.py +++ b/pypdf/_reader.py @@ -35,6 +35,7 @@ from datetime import datetime from io import BytesIO, FileIO, UnsupportedOperation from pathlib import Path +from types import TracebackType from typing import ( Any, Callable, @@ -45,6 +46,7 @@ Mapping, Optional, Tuple, + Type, Union, cast, ) @@ -278,6 +280,9 @@ class PdfReader: password: Decrypt PDF file at initialization. If the password is None, the file will not be decrypted. Defaults to ``None`` + + Can also be instantiated as a contextmanager which will automatically close + the underlying file pointer if passed via filenames. """ @property @@ -312,8 +317,10 @@ def __init__( __name__, ) + self._opened_automatically = False if isinstance(stream, (str, Path)): stream = FileIO(stream, "rb") + self._opened_automatically = True weakref.finalize(self, stream.close) self.read(stream) @@ -349,6 +356,20 @@ def close(self) -> None: """Close the underlying file handle""" self.stream.close() + def __enter__(self) -> "PdfReader": + """Use PdfReader as context manager""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + """Close the underlying stream if owned by the PdfReader""" + if self._opened_automatically: + self.close() + @property def root_object(self) -> DictionaryObject: """Provide access to "/Root". standardized with PdfWriter.""" diff --git a/tests/test_reader.py b/tests/test_reader.py index c9c6be9b3b..0727ecfe17 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -849,8 +849,8 @@ def test_extract_text_hello_world(): def test_read_path(): path = Path(RESOURCE_ROOT, "crazyones.pdf") - reader = PdfReader(path) - assert len(reader.pages) == 1 + with PdfReader(path) as reader: + assert len(reader.pages) == 1 def test_read_not_binary_mode(caplog):