Update unsloth for torch.cuda.amp deprecation (#2042)
* update deprecated unsloth torch cuda amp decorator

* WIP fix torch.cuda.amp deprecation

* lint

* relaxing torch version requirement

* remove use of partial

* remove use of partial

* lint

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
bursteratom and sunny authored Nov 13, 2024
1 parent c5eb9ea commit 342935c
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/axolotl/utils/gradient_checkpointing/unsloth.py
@@ -14,6 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from packaging import version
+
+torch_version = version.parse(torch.__version__)
+
+if torch_version < version.parse("2.4.0"):
+    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
+    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
+else:
+    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
+    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")


 class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     """

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch_cuda_amp_custom_fwd
     def forward(ctx, forward_function, hidden_states, *args):
         saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
         with torch.no_grad():
@@ -36,7 +46,7 @@ def forward(ctx, forward_function, hidden_states, *args):
         return output

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch_cuda_amp_custom_bwd
     def backward(ctx, dY):
         (hidden_states,) = ctx.saved_tensors
         hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
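For context, the class being patched is a torch.autograd.Function, so callers invoke it through its .apply entry point rather than instantiating it; the version shim above only decides which AMP decorator gets applied to forward/backward. A minimal sketch of a call site follows; the helper name and the idea of wrapping a decoder layer's forward are illustrative assumptions, not taken from this diff:

from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)


def offloaded_checkpoint(layer_forward, hidden_states, *layer_args):
    # Hypothetical helper: .apply is the standard entry point for
    # torch.autograd.Function subclasses. Per the diff, forward() runs the
    # wrapped callable under no_grad and stashes the input activations on CPU;
    # backward() moves them back to the GPU (the recomputation itself is
    # outside the shown hunks).
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        layer_forward, hidden_states, *layer_args
    )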
