skip the gpu memory checks if the device is set to 'auto' #609
Conversation
force-pushed from 8e36314 to 972bd42
Unfortunately I haven't been able to get this branch to work; the same error keeps being raised. I've added a little bit of logging to try to understand where the error is located:

```python
import functools
import logging

import torch

LOG = logging.getLogger("axolotl")


def check_cuda_device(default_value):
    """
    wraps a function and returns the default value instead of running the
    wrapped function if cuda isn't available or the device is auto
    :param default_value:
    :return:
    """

    def actual_decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            device = kwargs.get("device", args[0] if args else None)
            LOG.debug(f"Device in check_cuda_device: {device}")
            LOG.debug(f"Default value on check_cuda_device: {default_value}")
            if not torch.cuda.is_available() or device == "auto":
                return default_value
            return func(*args, **kwargs)

        return wrapper

    return actual_decorator
```

When I examine the output of the error, I find something interesting:
Before loading the model, the log traces are:

So I assume the GPU gets picked up correctly, but after the model is loaded, these are the logs:

And suddenly the device detected by PyTorch is cpu.

I added some print steps, and it shows
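For illustration, a minimal sketch of print steps along these lines (assuming a `model` variable holding the just-loaded model; the exact lines and output from the real run aren't reproduced here):

```python
import torch

# Sketch: compare what PyTorch reports with where the loaded weights actually sit.
print(f"cuda available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"current device: cuda:{torch.cuda.current_device()}")
print(f"model weights on: {next(model.parameters()).device}")  # `model` assumed in scope
```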
force-pushed from 8efef9a to 07c9436
skip the gpu memory checks if the device is set to 'auto' (…-cloud#609)

* skip the gpu memory checks if the device is set to 'auto'
* skip gpu mem logging if cpu too
* don't worry about log_gpu_memory_usage since it calls another annotated fn
* rename decorator internal
resolves #456
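For context, a minimal sketch of how the decorator is applied to a memory check; the wrapped function below is illustrative, not copied from the axolotl source:

```python
import torch

# check_cuda_device is the decorator defined earlier in this thread.
@check_cuda_device(0.0)
def gpu_memory_allocated_gb(device=0):
    # With the decorator, this returns 0.0 (the default) whenever CUDA is
    # unavailable or the device is "auto", and never touches the GPU.
    return torch.cuda.memory_allocated(device) / 1024.0**3
```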