diff --git a/pypdf/_writer.py b/pypdf/_writer.py index 0ac1524bc..2f110340c 100644 --- a/pypdf/_writer.py +++ b/pypdf/_writer.py @@ -152,11 +152,16 @@ class PdfWriter(PdfDocCommon): Typically data is added from a :class:`PdfReader`. Args: + *: 1st argument is assigned to fileobj or clone_from based on context: + assigned to clone_from if str/path to a non empty file or stream or PdfReader, + otherwise assigned to fileobj. + + fileobj: Output file/stream. To be used with context manager only. + clone_from: identical to fileobj (for compatibility) incremental: If true, loads the document and set the PdfWriter in incremental mode. - When writing incrementally, the original document is written first and new/modified content is appended. To be used for signed document/forms to keep signature valid. @@ -166,7 +171,8 @@ class PdfWriter(PdfDocCommon): def __init__( self, - fileobj: Union[None, PdfReader, StrByteType, Path] = "", + *args: Any, + fileobj: Union[None, StrByteType, Path] = "", clone_from: Union[None, PdfReader, StrByteType, Path] = None, incremental: bool = False, full: bool = False, @@ -202,39 +208,40 @@ def __init__( self._ID: Union[ArrayObject, None] = None self._info_obj: Optional[PdfObject] - if self.incremental: - if isinstance(fileobj, (str, Path)): - with open(fileobj, "rb") as f: - fileobj = BytesIO(f.read(-1)) - if isinstance(fileobj, BytesIO): - fileobj = PdfReader(fileobj) - if not isinstance(fileobj, PdfReader): - raise PyPdfError("Invalid type for incremental mode") - self._reader = fileobj # prev content is in _reader.stream - self._header = fileobj.pdf_header.encode() - self._readonly = True # !!!TODO: to be analysed - else: - self._header = b"%PDF-1.3" - self._info_obj = self._add_object( - DictionaryObject( - {NameObject("/Producer"): create_string_object("pypdf")} + manual_set_fileobj = True + if len(args) > 0: + if fileobj == "": + fileobj = args[0] + manual_set_fileobj = False + elif clone_from is None: + clone_from = args[0] + else: + logger_warning( + "unnamed param ignored: fileobj and clone_from already defined", + __name__, ) - ) def _get_clone_from( fileobj: Union[None, PdfReader, str, Path, IO[Any], BytesIO], clone_from: Union[None, PdfReader, str, Path, IO[Any], BytesIO], - ) -> Union[None, PdfReader, str, Path, IO[Any], BytesIO]: - if isinstance(fileobj, (str, Path, IO, BytesIO)) and ( - fileobj == "" or clone_from is not None + manual_set_fileobj: bool, + ) -> Tuple[ + Union[None, PdfReader, str, Path, IO[Any], BytesIO], + Union[None, str, Path, IO[Any], BytesIO], + ]: + if manual_set_fileobj or ( + isinstance(fileobj, (str, Path, IO, BytesIO)) + and (fileobj in ("", None) or clone_from is not None) ): - return clone_from + assert not isinstance(fileobj, PdfReader), "for mypy" + return clone_from, fileobj cloning = True if isinstance(fileobj, (str, Path)) and ( not Path(str(fileobj)).exists() or Path(str(fileobj)).stat().st_size == 0 ): cloning = False + if isinstance(fileobj, (IO, BytesIO)): t = fileobj.tell() fileobj.seek(-1, 2) @@ -242,13 +249,34 @@ def _get_clone_from( cloning = False fileobj.seek(t, 0) if cloning: - clone_from = fileobj - return clone_from + return fileobj, None + assert not isinstance(fileobj, PdfReader), "for mypy" + return clone_from, fileobj + + clone_from, fileobj = _get_clone_from(fileobj, clone_from, manual_set_fileobj) + + if self.incremental: + if isinstance(clone_from, (str, Path)): + with open(clone_from, "rb") as f: + clone_from = BytesIO(f.read(-1)) + if isinstance(clone_from, (IO, BytesIO)): + clone_from = PdfReader(clone_from) + if not isinstance(clone_from, PdfReader): + raise PyPdfError("Invalid type for incremental mode") + self._reader = clone_from # prev content is in _reader.stream + self._header = clone_from.pdf_header.encode() + self._readonly = True # !!!TODO: to be analysed + else: + self._header = b"%PDF-1.3" + self._info_obj = self._add_object( + DictionaryObject( + {NameObject("/Producer"): create_string_object("pypdf")} + ) + ) - clone_from = _get_clone_from(fileobj, clone_from) # to prevent overwriting self.temp_fileobj = fileobj - self.fileobj = "" + self.fileobj: Union[None, StrByteType, Path] = "" self.with_as_usage = False # The root of our page tree node. pages = DictionaryObject() @@ -354,10 +382,8 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None: def __enter__(self) -> "PdfWriter": """Store that writer is initialized by 'with'.""" - t = self.temp_fileobj - self.__init__() # type: ignore + self.fileobj = self.temp_fileobj self.with_as_usage = True - self.fileobj = t # type: ignore return self def __exit__( @@ -1393,7 +1419,7 @@ def write(self, stream: Union[Path, StrByteType]) -> Tuple[bool, IO[Any]]: self.write_stream(stream) - if self.with_as_usage: + if my_file: stream.close() return my_file, stream diff --git a/tests/test_writer.py b/tests/test_writer.py index 0cd2d03f8..382b1c26e 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -2480,3 +2480,64 @@ def test_append_pdf_with_dest_without_page(caplog): writer.append(reader) assert "/__WKANCHOR_8" not in writer.named_destinations assert len(writer.named_destinations) == 3 + + +def test_writer_contextmanager(tmp_path, caplog): + """To test the writer with context manager, cf #2912""" + pdf_path = str(RESOURCE_ROOT / "crazyones.pdf") + with PdfWriter(pdf_path) as w: + assert len(w.pages) > 0 + assert not w.fileobj + with open(pdf_path, "rb") as f, PdfWriter(f) as w: + assert len(w.pages) > 0 + assert not w.fileobj + with open(pdf_path, "rb") as f, PdfWriter(BytesIO(f.read(-1))) as w: + assert len(w.pages) > 0 + assert not w.fileobj + + tmp_file = tmp_path / "out.pdf" + with PdfWriter(tmp_file) as w: + assert len(w.pages) == 0 + + with open(tmp_file, "wb") as f1, open(pdf_path, "rb") as f: + f1.write(f.read(-1)) + with PdfWriter(tmp_file) as w: + assert len(w.pages) > 0 + assert tmp_file.stat().st_size > 0 + + with PdfWriter(tmp_file, incremental=True) as w: + assert w._reader + assert not w.fileobj + assert tmp_file.stat().st_size > 0 + + with PdfWriter(clone_from=tmp_file) as w: + assert len(w.pages) > 0 + assert not w.fileobj + assert tmp_file.stat().st_size > 0 + + with PdfWriter(fileobj=tmp_file) as w: + assert len(w.pages) == 0 + assert 8 <= tmp_file.stat().st_size <= 1024 + + b = BytesIO() + with PdfWriter(fileobj=b) as w: + assert len(w.pages) == 0 + assert not b.closed + assert 8 <= len(b.getbuffer()) <= 1024 + + with NamedTemporaryFile(mode="wb", suffix=".pdf", delete=True) as tmp: + with PdfWriter(pdf_path, fileobj=tmp, incremental=True) as w: + assert w._reader + assert not tmp.closed + assert Path(tmp.name).stat().st_size == Path(pdf_path).stat().st_size + + with PdfWriter(tmp_file) as w: + assert len(w.pages) == 0 + + caplog.clear() + b = BytesIO() + with PdfWriter("ignored", fileobj=b, clone_from=pdf_path) as w: + pass + assert ( + "unnamed param ignored: fileobj and clone_from already defined" in caplog.text + )