From 7dbb4eb222b6a5e449db4e55e49b54b0f5d09fd8 Mon Sep 17 00:00:00 2001 From: edielam Date: Wed, 27 Sep 2023 03:09:02 -0400 Subject: [PATCH] reformatting the container_types function in the jax backend --- ivy/functional/backends/jax/general.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index 6bc3efbec82d3..329c989a03e87 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -7,7 +7,7 @@ from numbers import Number from operator import mul from functools import reduce as _reduce -from typing import Optional, Union, Sequence, Callable, Tuple +from typing import Optional, Union, Sequence, Callable, Tuple, List, Type import multiprocessing as _multiprocessing import importlib @@ -22,7 +22,17 @@ from . import backend_version -def container_types(): +def container_types() -> List[Type]: + """ + Gets list of container types supported. + + Returns: + List[Type]: List containing the FlatMapping container type. + + Examples: + >>> container_types() + [FlatMapping] + """ flat_mapping_spec = importlib.util.find_spec( "FlatMapping", "haiku._src.data_structures" ) @@ -30,6 +40,7 @@ def container_types(): from haiku._src.data_structures import FlatMapping else: FlatMapping = importlib.util.module_from_spec(flat_mapping_spec) + return [FlatMapping]