diff --git a/CHANGELOG.md b/CHANGELOG.md index f3b052c..cdc57ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.3.0] - 2024-02-28 + +### Added +- Added multipart body class to support multipart serialization. [microsoft/kiota#3030](https://github.com/microsoft/kiota/issues/3030) + +### Changed + ## [1.2.0] - 2024-01-31 ### Added diff --git a/kiota_abstractions/_version.py b/kiota_abstractions/_version.py index 42c1c2b..96ce313 100644 --- a/kiota_abstractions/_version.py +++ b/kiota_abstractions/_version.py @@ -1 +1 @@ -VERSION: str = "1.2.0" +VERSION: str = "1.3.0" diff --git a/kiota_abstractions/multipart_body.py b/kiota_abstractions/multipart_body.py new file mode 100644 index 0000000..56d121a --- /dev/null +++ b/kiota_abstractions/multipart_body.py @@ -0,0 +1,144 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. All Rights Reserved. +# Licensed under the MIT License. +# See License in the project root for license information. +# ------------------------------------ +from __future__ import annotations + +import io +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar + +from .serialization import Parsable + +if TYPE_CHECKING: + from .request_adapter import RequestAdapter + from .serialization import ParseNode, SerializationWriter + +T = TypeVar("T") + + +@dataclass +class MultipartBody(Parsable, Generic[T]): + """Represents a multipart body for a request or a response. + Example usage: + multipart = MultipartBody() + multipart.add_or_replace_part("file", "image/jpeg", open("image.jpg", "rb").read()) + multipart.add_or_replace_part("text", "text/plain", "Hello, World!") + with open("output.txt", "w") as output_file: + multipart.serialize(output_file) + """ + boundary: str = str(uuid.uuid4()) + parts: Dict[str, Tuple[str, Any]] = field(default_factory=dict) + request_adapter: Optional[RequestAdapter] = None + + def add_or_replace_part(self, part_name: str, content_type: str, part_value: T) -> None: + """Adds or replaces a part to the multipart body. + + Args: + part_name (str): The name of the part to add or replace. + content_type (str): The content type of the part. + part_value (T): The value of the part. + + Returns: + None + """ + if not part_name: + raise ValueError("Part name cannot be null") + if not content_type: + raise ValueError("Content type cannot be null") + if not part_value: + raise ValueError("Part value cannot be null") + value: Tuple[str, Any] = (content_type, part_value) + self.parts[self._normalize_part_name(part_name)] = value + + def get_part_value(self, part_name: str) -> T: + """Gets the value of a part from the multipart body.""" + if not part_name: + raise ValueError("Part name cannot be null") + value = self.parts.get(self._normalize_part_name(part_name)) + return value[1] if value else None + + def remove_part(self, part_name: str) -> bool: + """Removes a part from the multipart body. + + Args: + part_name (str): The name of the part to remove. + + Returns: + bool: True if the part was removed, False otherwise. + """ + if not part_name: + raise ValueError("Part name cannot be null") + return self.parts.pop(self._normalize_part_name(part_name), None) is not None + + def get_field_deserializers(self) -> Dict[str, Callable[[ParseNode], None]]: + """Gets the deserialization information for this object. + + Returns: + Dict[str, Callable[[ParseNode], None]]: The deserialization information for this + object where each entry is a property key with its deserialization callback. + """ + raise NotImplementedError() + + def serialize(self, writer: SerializationWriter) -> None: + """Writes the objects properties to the current writer. + + Args: + writer (SerializationWriter): The writer to write to. + """ + if not writer: + raise ValueError("Serialization writer cannot be null") + if not self.request_adapter or not self.request_adapter.get_serialization_writer_factory(): + raise ValueError("Request adapter or serialization writer factory cannot be null") + if not self.parts: + raise ValueError("No parts to serialize") + + first = True + for part_name, part_value in self.parts.items(): + if first: + first = False + else: + self._add_new_line(writer) + + writer.write_str_value("", f"--{self.boundary}") + writer.write_str_value("Content-Type", f"{part_value[0]}") + writer.write_str_value("Content-Disposition", f'form-data; name="{part_name}"') + self._add_new_line(writer) + + if isinstance(part_value[1], Parsable): + self._write_parsable(writer, part_value[1]) + elif isinstance(part_value[1], str): + writer.write_str_value("", part_value[1]) + elif isinstance(part_value[1], bytes): + writer.write_bytes_value("", part_value[1]) + elif isinstance(part_value[1], io.IOBase): + writer.write_bytes_value("", part_value[1].read()) + else: + raise ValueError(f"Unsupported type {type(part_value[1])} for part {part_name}") + + self._add_new_line(writer) + writer.write_str_value("", f"--{self.boundary}--") + + def _normalize_part_name(self, original: str) -> str: + return original.lower() + + def _add_new_line(self, writer: SerializationWriter) -> None: + writer.write_str_value("", "") + + def _write_parsable(self, writer, part_value) -> None: + if not self.request_adapter or not self.request_adapter.get_serialization_writer_factory(): + raise ValueError("Request adapter or serialization writer factory cannot be null") + part_writer = ( + self.request_adapter.get_serialization_writer_factory().get_serialization_writer( + part_value[0] + ) + ) + part_writer.write_object_value("", part_value[1], None) + part_content = part_writer.get_serialized_content() + if hasattr(part_content, "seek"): # seekable + part_content.seek(0) + writer.write_bytes_value("", part_content.read()) #type: ignore + else: + writer.write_bytes_value("", part_content) diff --git a/kiota_abstractions/request_information.py b/kiota_abstractions/request_information.py index 37e02de..04cf970 100644 --- a/kiota_abstractions/request_information.py +++ b/kiota_abstractions/request_information.py @@ -1,3 +1,10 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. All Rights Reserved. +# Licensed under the MIT License. +# See License in the project root for license information. +# ------------------------------------ +from __future__ import annotations + from dataclasses import fields, is_dataclass from datetime import date, datetime, time, timedelta from enum import Enum @@ -13,6 +20,7 @@ from .base_request_configuration import RequestConfiguration from .headers_collection import HeadersCollection from .method import Method +from .multipart_body import MultipartBody from .request_option import RequestOption from .serialization import Parsable, SerializationWriter @@ -69,7 +77,7 @@ def __init__( self.headers: HeadersCollection = HeadersCollection() # The Request Body - self.content: Optional[BytesIO] = None + self.content: Optional[bytes] = None def configure(self, request_configuration: RequestConfiguration) -> None: """Configures the current request information headers, query parameters, and options @@ -145,8 +153,8 @@ def remove_request_options(self, options: List[RequestOption]) -> None: def set_content_from_parsable( self, - request_adapter: Optional["RequestAdapter"], - content_type: Optional[str], + request_adapter: RequestAdapter, + content_type: str, values: Union[T, List[T]], ) -> None: """Sets the request body from a model with the specified content type. @@ -161,7 +169,9 @@ def set_content_from_parsable( self._create_parent_span_name("set_content_from_parsable") ) as span: writer = self._get_serialization_writer(request_adapter, content_type, values, span) - + if isinstance(values, MultipartBody): + content_type += f"; boundary={values.boundary}" + values.request_adapter = request_adapter if isinstance(values, list): writer.write_collection_of_object_values(None, values) span.set_attribute(self.REQUEST_TYPE_KEY, "[]") @@ -217,7 +227,7 @@ def set_content_from_scalar( writer_func(None, values) self._set_content_and_content_type_header(writer, content_type) - def set_stream_content(self, value: BytesIO, content_type: Optional[str] = None) -> None: + def set_stream_content(self, value: bytes, content_type: Optional[str] = None) -> None: """Sets the request body to be a binary stream. Args: diff --git a/kiota_abstractions/serialization/parse_node.py b/kiota_abstractions/serialization/parse_node.py index 086e8c8..20d0ab9 100644 --- a/kiota_abstractions/serialization/parse_node.py +++ b/kiota_abstractions/serialization/parse_node.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from datetime import date, datetime, time, timedelta from enum import Enum -from io import BytesIO from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar from uuid import UUID @@ -166,11 +165,11 @@ def get_object_value(self, factory: ParsableFactory) -> Parsable: pass @abstractmethod - def get_bytes_value(self) -> BytesIO: - """Get a bytearray value from the nodes + def get_bytes_value(self) -> bytes: + """Get a bytes value from the nodes Returns: - bytearray: The bytearray value from the nodes + bytes: The bytes value from the nodes """ pass diff --git a/kiota_abstractions/serialization/parse_node_factory.py b/kiota_abstractions/serialization/parse_node_factory.py index 5f24060..e926a03 100644 --- a/kiota_abstractions/serialization/parse_node_factory.py +++ b/kiota_abstractions/serialization/parse_node_factory.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from io import BytesIO from .parse_node import ParseNode @@ -18,12 +17,12 @@ def get_valid_content_type(self) -> str: pass @abstractmethod - def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode: + def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode: """Creates a ParseNode from the given binary stream and content type Args: content_type (str): The content type of the binary stream - content (BytesIO): The array buffer to read from + content (bytes): The array buffer to read from Returns: ParseNode: A ParseNode that can deserialize the given binary stream diff --git a/kiota_abstractions/serialization/parse_node_factory_registry.py b/kiota_abstractions/serialization/parse_node_factory_registry.py index f626dbe..094d63a 100644 --- a/kiota_abstractions/serialization/parse_node_factory_registry.py +++ b/kiota_abstractions/serialization/parse_node_factory_registry.py @@ -1,7 +1,6 @@ from __future__ import annotations import re -from io import BytesIO from typing import Dict from .parse_node import ParseNode @@ -31,7 +30,7 @@ def get_valid_content_type(self) -> str: "The registry supports multiple content types. Get the registered factory instead" ) - def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode: + def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode: if not content_type: raise Exception("Content type cannot be null") if not content: diff --git a/kiota_abstractions/serialization/parse_node_proxy_factory.py b/kiota_abstractions/serialization/parse_node_proxy_factory.py index 413bd7e..8e23d99 100644 --- a/kiota_abstractions/serialization/parse_node_proxy_factory.py +++ b/kiota_abstractions/serialization/parse_node_proxy_factory.py @@ -4,7 +4,6 @@ # See License in the project root for license information. # ------------------------------------------------------------------------------ -from io import BytesIO from typing import Callable from .parsable import Parsable @@ -44,12 +43,12 @@ def get_valid_content_type(self) -> str: """ return self._concrete.get_valid_content_type() - def get_root_parse_node(self, content_type: str, content: BytesIO) -> ParseNode: + def get_root_parse_node(self, content_type: str, content: bytes) -> ParseNode: """Create a parse node from the given stream and content type. Args: content_type (str): The content type of the parse node. - content (BytesIO): The stream to read the parse node from. + content (bytes): The stream to read the parse node from. Returns: ParseNode: A parse node. diff --git a/kiota_abstractions/serialization/serialization_writer.py b/kiota_abstractions/serialization/serialization_writer.py index bfdb0c9..9377f1c 100644 --- a/kiota_abstractions/serialization/serialization_writer.py +++ b/kiota_abstractions/serialization/serialization_writer.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod from datetime import date, datetime, time, timedelta from enum import Enum -from io import BytesIO from typing import Any, Callable, Dict, List, Optional, TypeVar from uuid import UUID @@ -146,13 +145,13 @@ def write_collection_of_enum_values( pass @abstractmethod - def write_bytes_value(self, key: Optional[str], value: BytesIO) -> None: + def write_bytes_value(self, key: Optional[str], value: bytes) -> None: """Writes the specified byte array as a base64 string to the stream with an optional given key. Args: key (Optional[str]): The key to be used for the written value. May be null. - value (BytesIO): The byte array to be written. + value (bytes): The bytes to be written. """ pass @@ -198,11 +197,11 @@ def write_additional_data_value(self, value: Dict[str, Any]) -> None: pass @abstractmethod - def get_serialized_content(self) -> BytesIO: + def get_serialized_content(self) -> bytes: """Gets the value of the serialized content. Returns: - BytesIO: The value of the serialized content. + bytes: The value of the serialized content. """ pass diff --git a/tests/conftest.py b/tests/conftest.py index 4505197..d30c02a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,11 +9,13 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import Mock import pytest from kiota_abstractions.authentication.access_token_provider import AccessTokenProvider from kiota_abstractions.authentication.allowed_hosts_validator import AllowedHostsValidator +from kiota_abstractions.multipart_body import MultipartBody from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.request_information import RequestInformation from kiota_abstractions.serialization import ( @@ -21,6 +23,7 @@ Parsable, ParseNode, SerializationWriter, + SerializationWriterFactory ) from kiota_abstractions.store import BackedModel, BackingStore, BackingStoreFactorySingleton @@ -111,6 +114,8 @@ class TestEnum(Enum): VALUE2 = "value2" VALUE3 = "value3" + __test__ = False + @dataclass class QueryParams: dataset: Union[TestEnum, List[TestEnum]] @@ -138,6 +143,21 @@ def mock_access_token_provider(): @pytest.fixture -def mock_request_adapter(mocker): - mocker.patch.multiple(RequestAdapter, __abstractmethods__=set()) - return RequestAdapter() +def mock_request_adapter(): + request_adapter = Mock(spec=RequestAdapter) + return request_adapter + +@pytest.fixture +def mock_serialization_writer(): + return Mock(spec=SerializationWriter) + + +@pytest.fixture +def mock_serialization_writer_factory(): + mock_factory = Mock(spec=SerializationWriterFactory) + return mock_factory + +@pytest.fixture +def mock_multipart_body(): + mock_multipart_body = MultipartBody() + return mock_multipart_body \ No newline at end of file diff --git a/tests/test_multipart_body.py b/tests/test_multipart_body.py new file mode 100644 index 0000000..0431dce --- /dev/null +++ b/tests/test_multipart_body.py @@ -0,0 +1,51 @@ +import pytest + +from kiota_abstractions.serialization import SerializationWriter +from kiota_abstractions.multipart_body import MultipartBody +def test_defensive(): + """Tests initialization of MultipartBody objects + """ + multipart = MultipartBody() + with pytest.raises(ValueError) as excinfo: + multipart.add_or_replace_part(None, "text/plain", "Hello, World!") + assert "Part name cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.add_or_replace_part("text", None, "Hello, World!") + assert "Content type cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.add_or_replace_part("text", "text/plain", None) + assert "Part value cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.get_part_value(None) + assert "Part name cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.get_part_value("") + assert "Part name cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.remove_part(None) + assert "Part name cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.remove_part("") + assert "Part name cannot be null" in str(excinfo.value) + with pytest.raises(ValueError) as excinfo: + multipart.serialize(None) + assert "Serialization writer cannot be null" in str(excinfo.value) + +def test_add_or_replace_part(): + """Tests adding or replacing a part in the multipart body + """ + multipart = MultipartBody() + multipart.add_or_replace_part("text", "text/plain", "Hello, World!") + assert multipart.get_part_value("text") == "Hello, World!" + multipart.add_or_replace_part("text", "text/plain", "Hello, World! 2") + assert multipart.get_part_value("text") == "Hello, World! 2" + +def test_remove_part(): + """Tests removing a part from the multipart body + """ + multipart = MultipartBody() + multipart.add_or_replace_part("text", "text/plain", "Hello, World!") + assert multipart.get_part_value("text") == "Hello, World!" + assert multipart.remove_part("text") + assert not multipart.get_part_value("text") + assert not multipart.remove_part("text") \ No newline at end of file diff --git a/tests/test_request_information.py b/tests/test_request_information.py index aad5209..58e80a7 100644 --- a/tests/test_request_information.py +++ b/tests/test_request_information.py @@ -1,12 +1,14 @@ import pytest from dataclasses import dataclass from typing import Optional +from unittest.mock import Mock from kiota_abstractions.request_information import RequestInformation from kiota_abstractions.headers_collection import HeadersCollection from kiota_abstractions.base_request_configuration import RequestConfiguration from kiota_abstractions.method import Method + from .conftest import TestEnum, QueryParams @@ -132,7 +134,24 @@ def get_query_parameter(self,original_name: Optional[str] = None) -> str: assert mock_request_information.query_parameters == {"%24filter": "query1"} assert not mock_request_information.request_options - +def test_sets_boundary_on_multipart_request_body( + mock_request_information, + mock_request_adapter, + mock_multipart_body, + mock_serialization_writer, + mock_serialization_writer_factory + ): + """Tests setting the boundary on a multipart request + """ + mock_request_information.http_method = Method.POST + mock_serialization_writer_factory.get_serialization_writer = Mock(return_value=mock_serialization_writer) + mock_request_adapter.get_serialization_writer_factory = Mock(return_value=mock_serialization_writer_factory) + mock_multipart_body.request_adapter = mock_request_adapter + mock_request_information.set_content_from_parsable(mock_request_adapter, "multipart/form-data", mock_multipart_body) + assert mock_multipart_body.boundary + assert mock_request_information.headers.get("content-type") == {"multipart/form-data; boundary=" + mock_multipart_body.boundary} + + def test_sets_enum_value_in_query_parameters(): """Tests setting enum values in query parameters """ @@ -162,5 +181,6 @@ def test_sets_enum_values_in_path_parameters(): request_info = RequestInformation(Method.GET, "https://example.com/{dataset}") request_info.path_parameters["dataset"] = [TestEnum.VALUE1, TestEnum.VALUE2] assert request_info.url == "https://example.com/value1%2Cvalue2" + \ No newline at end of file