Device support improvements (MPS) #1054
Conversation
Thank you for this! This is a very nice PR, but the modification is so large that it will take some time to verify. Please understand.
Thank you for the great work! I have started my review and have a few concerns. First, I do not like the use of environment variables; it also makes breaking changes to the existing environment. Is it possible to get the target device from the Accelerator instead? Also, I do not want to change the IPEX part, as it is maintained by another author and not by me. I would like to keep it separate from the rest of the scripts as much as possible. Is it possible to take these into account?
What breaking changes are you referring to? 🤔 If no environment variable is set, nothing changes from the current behavior.
Sure, that could be used too, but that would certainly be a change from the previous behavior.
I can move the IPEX utility into a separate file, sure. The change itself (that the files are less "littered" with the IPEX initialization code, with its try-excepts and all) is for the better, though, I think? EDIT: I moved the IPEX stuff to #1060 instead, to keep this simple.
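For reference, the shape of that deduplication is roughly this (just a sketch; the `library.ipex` path and the helper name are my assumptions, not necessarily what #1060 ends up with):

```python
import torch


def try_init_ipex() -> None:
    # Sketch: do the "import IPEX and initialize it if an XPU is present"
    # dance once, instead of repeating the try/except in every script.
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401

        if torch.xpu.is_available():
            from library.ipex import ipex_init  # assumed module path

            ipex_init()
    except Exception:
        # No IPEX / no XPU: silently continue on CUDA, MPS or CPU.
        pass
```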
(Force-pushed from bf120b4 to 34d5260)
Oh, sorry, I didn't notice that. That's nice! However, I prefer not to use the environment variables... I would like to have all settings in one place.
In my understanding, if we use the device from the Accelerator, the environment variables are not needed.

Thank you! However, I just don't want to touch any part of IPEX, to keep the responsibility separate 😅
I changed this to prefer the Accelerate device. It will be close enough, unless someone is actually using a distributed device type for Accelerate and still wants to do inference with a local device (CUDA, MPS, CPU). Those users are probably few and far between, and I'm sure they'll be technical enough to find a workaround 😄 As for the environment variables, I have removed them.
Aye, I pinged the IPEX maintainer in that PR. It's now a very simple refactoring of the repeated initialization code into a separate function, so it's functionally the same but keeps the IPEX stuff even better encapsulated out of the main training scripts. Oh, and by the way, thank you for your work on sd-scripts!
Thank you for updating!

In addition, I am very sorry that I did not notice this in the previous comment, but it just doesn't seem to work correctly when using sd-scripts without setting up accelerate (when using the LoRA-related utilities, or using model conversion only). So I think that, in the training script, the simple check is a redundant but simple and effective way to get the device.
No worries – can you give me an example command line (and the resulting error/traceback) that doesn't work? The code here currently does try to fall back to CUDA/MPS/CPU if there's any issue talking to Accelerate.
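Roughly, the current logic is shaped like this (a simplified sketch, not the exact code in the PR; the helper name is made up):

```python
import torch


def resolve_device() -> torch.device:
    # Ask Accelerate for its device, but fall back to a local
    # CUDA/MPS/CPU device if Accelerate is not set up or errors out.
    try:
        from accelerate import Accelerator

        return Accelerator().device
    except Exception:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
```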
Oh, sorry. Even if we remove that, I think the simple way may be better.
@kohya-ss Understood – done!
Thank you for updating! I will review and merge this soon!
I'll rebase this now that #1060 was merged.
(Force-pushed from 32dca12 to 04cef21)
@kohya-ss Rebased, ready for review again. By the way, would you prefer these PRs to target the dev branch instead of main?
Thank you! I prefer the dev branch.
Thank you again for the PR. I have merged it into my new branch and will add some small modifications. I will merge the branch to dev (and main) soon. I appreciate your understanding.
@kohya-ss Sure thing. Thank you for all of your work! 😄
Importing `device_utils` before running `ipex_init()` breaks IPEX support: `torch.cuda.is_available()` returns `False` without `ipex_init()`, so sd-scripts falls back to the CPU.
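In other words, if `device_utils` evaluates the CUDA check at import time, the ordering matters. Roughly (the module path and helper name here are assumptions based on the PR discussion, not verified against the current code):

```python
# IPEX has to be initialized before device_utils is imported, otherwise
# the cached torch.cuda.is_available() check is False and XPU users end
# up on the CPU.
from library.ipex import ipex_init  # assumed module path

ipex_init()  # patch torch so the CUDA checks see the XPU device

from library import device_utils  # only now is the device check safe

device = device_utils.get_preferred_device()  # assumed helper name
```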
This PR follows what #666 started to make things work better for MPS (Apple Silicon).
It turns out that `accelerate` already prefers MPS by default, but the third commit in the PR makes the `accelerate` device more obvious to the user.

Beyond that, this PR:

- refactors the duplicated memory-cleanup code (`cuda.empty_cache()`, `gc.collect()`) into a single function and teaches it about MPS
- ~~refactors the duplicated code to initialize IPEX (XPU) into a single function~~ Moved to Deduplicate ipex initialization code #1060
- adds a `device_utils` helper that prefers `cuda`, then `mps`, then `cpu`

I tested that training basically works on my machine, but I don't have an XPU or suitable CUDA machine to test the other changes on.
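For the record, the two helpers are roughly shaped like this (a simplified sketch; the exact names in the PR may differ, and `torch.mps.empty_cache()` needs a reasonably recent PyTorch):

```python
import gc

import torch


def get_preferred_device() -> torch.device:
    # Prefer CUDA, then MPS (Apple Silicon), then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def clean_memory_on_device(device: torch.device) -> None:
    # One place for the repeated "free caches" dance, now MPS-aware.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()
```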