diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 0c141e82ab52..74db25af64bf 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -445,8 +445,8 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually") def init_model_parallel(self): - from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes from megatron.core import parallel_state + from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes for model_module in self: if not self._cpu: