Issue with too much memory usage when running in parallel with mx.distributed #1220
sck-at-ucy started this conversation in General
I am trying to refactor my physics-informed transformer model to run in parallel. The idea is to create the various dataset arrays on rank 0, create empty (zero) counterpart arrays of the same shape on ranks > 0, and then use mx.distributed.all_sum so that every rank ends up with a copy of the dataset. I then slice the dataset arrays to create local copies for each rank. Once I no longer need the original complete dataset, I delete it, which I thought would free the memory, but apparently it does not: the code uses far more memory than running the entire dataset on a single node.
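Roughly, the pattern looks like this (a simplified sketch, not my actual code; `load_full_dataset`, `full_shape`, and the even split along the first axis are placeholders):

```python
import mlx.core as mx

world = mx.distributed.init()
rank, size = world.rank(), world.size()

# Rank 0 builds the real dataset; every other rank allocates zeros of the
# same shape and dtype, so all_sum effectively broadcasts rank 0's data.
if rank == 0:
    full_data = load_full_dataset()     # placeholder for the actual loading code
else:
    full_data = mx.zeros(full_shape)    # placeholder shape/dtype matching rank 0

full_data = mx.distributed.all_sum(full_data)
mx.eval(full_data)

# Each rank keeps only its own slice of the broadcast dataset.
chunk = full_data.shape[0] // size
local_data = full_data[rank * chunk : (rank + 1) * chunk]
mx.eval(local_data)

# Drop the reference to the complete dataset, expecting its memory to be
# released. This is the step that does not seem to help.
del full_data
```

One thing I am unsure about is whether the slices end up as views into the full arrays, keeping the full buffers alive even after the `del`; if so, perhaps an explicit copy of each local slice would be needed before deleting the original.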
What am I doing wrong? Is there a better memory-management approach when using mx.distributed, given the absence of a distributed broadcast primitive?
I would be thankful for any advice from @awni @angeloskath :)