From 0d76cc15848336e06c3f892ace5107ec4fa5fbfb Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 11 Jul 2023 19:58:04 +0000 Subject: [PATCH] [FSDP][Easy] Allow `ModuleWrapPolicy` to take `Iterable` ghstack-source-id: 6e56f5c32da86336e4a943a49f237d448c532cad Pull Request resolved: https://github.com/pytorch/pytorch/pull/104999 --- torch/distributed/fsdp/wrap.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index b47eebe7ae3a2..baf8aa72de72a 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -198,13 +198,14 @@ def _module_wrap_policy( class ModuleWrapPolicy(_FSDPPolicy): """This is a wrapper around :func:`_module_wrap_policy`.""" - def __init__(self, module_classes: Set[Type[nn.Module]]): + def __init__(self, module_classes: Iterable[Type[nn.Module]]): + module_classes_set = set(module_classes) self._policy: Callable = functools.partial( _module_wrap_policy, - module_classes=module_classes, + module_classes=module_classes_set, ) - self._module_classes = module_classes - self._module_classes_str = str(module_classes) + self._module_classes = module_classes_set + self._module_classes_str = str(module_classes_set) @property def policy(self):