diff --git a/ivy/functional/backends/jax/general.py b/ivy/functional/backends/jax/general.py index 6199d7585877..7ddc2c81d093 100644 --- a/ivy/functional/backends/jax/general.py +++ b/ivy/functional/backends/jax/general.py @@ -6,7 +6,7 @@ from functools import reduce as _reduce from numbers import Number from operator import mul -from typing import Callable, Optional, Sequence, Tuple, Union +from typing import Optional, Union, Sequence, Callable, Tuple, List, Type # global import jax @@ -24,7 +24,17 @@ from . import backend_version -def container_types(): +def container_types() -> List[Type]: + """ + Gets list of container types supported. + + Returns: + List[Type]: List containing the FlatMapping container type. + + Examples: + >>> container_types() + [FlatMapping] + """ flat_mapping_spec = importlib.util.find_spec( "FlatMapping", "haiku._src.data_structures" ) @@ -32,6 +42,7 @@ def container_types(): from haiku._src.data_structures import FlatMapping else: FlatMapping = importlib.util.module_from_spec(flat_mapping_spec) + return [FlatMapping]