Skip to content

Commit

Permalink
fix: fixes typing issues discovered from github api generation
Browse files Browse the repository at this point in the history
  • Loading branch information
andrueastman committed Nov 7, 2024
1 parent 72eb943 commit 6e68068
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[ResponseType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
Expand Down
17 changes: 6 additions & 11 deletions packages/abstractions/kiota_abstractions/request_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from .request_adapter import RequestAdapter

Url = str
T = TypeVar("T", bound=Parsable)
T = TypeVar("T", bool, str, int, float, UUID, datetime, timedelta, date, time, bytes)
U = TypeVar("U", bound=Parsable)
QueryParameters = TypeVar('QueryParameters')
OBSERVABILITY_TRACER_NAME = "microsoft-python-kiota-abstractions"
tracer = trace.get_tracer(OBSERVABILITY_TRACER_NAME, VERSION)
Expand Down Expand Up @@ -155,20 +156,20 @@ def set_content_from_parsable(
self,
request_adapter: RequestAdapter,
content_type: str,
values: Union[T, List[T]],
values: Union[U, List[U]],
) -> None:
"""Sets the request body from a model with the specified content type.
Args:
request_adapter (Optional[RequestAdapter]): The adapter service to get the serialization
writer from.
content_type (Optional[str]): the content type.
values (Union[T, List[T]]): the models.
values (Union[U, List[U]]): the models.
"""
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_parsable")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)
if isinstance(values, MultipartBody):
content_type += f"; boundary={values.boundary}"
values.request_adapter = request_adapter
Expand Down Expand Up @@ -198,7 +199,7 @@ def set_content_from_scalar(
with tracer.start_as_current_span(
self._create_parent_span_name("set_content_from_scalar")
) as span:
writer = self._get_serialization_writer(request_adapter, content_type, values, span)
writer = self._get_serialization_writer(request_adapter, content_type, span)

if isinstance(values, list):
writer.writer = writer.write_collection_of_primitive_values(None, values)
Expand Down Expand Up @@ -255,15 +256,13 @@ def _get_serialization_writer(
self,
request_adapter: Optional["RequestAdapter"],
content_type: Optional[str],
values: Union[T, List[T]],
parent_span: trace.Span,
):
"""_summary_
Args:
request_adapter (RequestAdapter): _description_
content_type (str): _description_
values (Union[T, List[T]]): _description_
"""
_span = self._start_local_tracing_span("_get_serialization_writer", parent_span)
try:
Expand All @@ -275,10 +274,6 @@ def _get_serialization_writer(
exc = ValueError("Content Type cannot be null")
_span.record_exception(exc)
raise exc
if not values:
exc = ValueError("Values cannot be null")
_span.record_exception(exc)
raise exc
return request_adapter.get_serialization_writer_factory(
).get_serialization_writer(content_type)
finally:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_time_value(self) -> Optional[time]:
pass

@abstractmethod
def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -128,7 +128,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> Optional[List[T]
pass

@abstractmethod
def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of model object values of the node
Args:
factory (ParsableFactory): The factory to use to create the model object.
Expand Down
2 changes: 1 addition & 1 deletion packages/http/httpx/kiota_http/httpx_request_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def send_collection_async(
async def send_collection_of_primitive_async(
self,
request_info: RequestInformation,
response_type: ResponseType,
response_type: type[ResponseType],
error_map: Optional[Dict[str, type[ParsableFactory]]],
) -> Optional[List[ResponseType]]:
"""Excutes the HTTP request specified by the given RequestInformation and returns the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_child_node(self, field_name: str) -> Optional[ParseNode]:
return FormParseNode(self._fields[field_name])
return None

def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -189,7 +189,7 @@ def get_collection_of_primitive_values(self, primitive_type: type) -> Optional[L
return result
raise Exception(f"Encountered an unknown type during deserialization {primitive_type}")

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
raise Exception("Collection of object values is not supported with uri form encoding.")

def get_collection_of_enum_values(self, enum_class: K) -> Optional[List[K]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj
return None

def get_collection_of_primitive_values(self, primitive_type: Any) -> Optional[List[T]]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> Optional[List[T]]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -161,7 +161,7 @@ def func(item):
return list(map(func, json.loads(self._json_node)))
return list(map(func, list(self._json_node)))

def get_collection_of_object_values(self, factory: ParsableFactory) -> Optional[List[U]]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> Optional[List[U]]:
"""Gets the collection of type U values from the json node
Returns:
List[U]: The collection of model object values of the node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_time_value(self) -> Optional[time]:
return datetime_obj.time()
return None

def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
def get_collection_of_primitive_values(self, primitive_type: type[T]) -> List[T]:
"""Gets the collection of primitive values of the node
Args:
primitive_type: The type of primitive to return.
Expand All @@ -142,7 +142,7 @@ def get_collection_of_primitive_values(self, primitive_type) -> List[T]:
"""
raise Exception(self.NO_STRUCTURED_DATA_MESSAGE)

def get_collection_of_object_values(self, factory: ParsableFactory) -> List[U]:
def get_collection_of_object_values(self, factory: ParsableFactory[U]) -> List[U]:
"""Gets the collection of type U values from the text node
Returns:
List[U]: The collection of model object values of the node
Expand Down

0 comments on commit 6e68068

Please sign in to comment.