From 746032ba57b44038ac3fbfaa1632af86e69b1050 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Fri, 26 Jul 2024 05:24:09 -0700 Subject: [PATCH] use a torch.compile friendly pytree --- frame_averaging_pytorch/frame_averaging.py | 2 +- pyproject.toml | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/frame_averaging_pytorch/frame_averaging.py b/frame_averaging_pytorch/frame_averaging.py index 8268433..52bce7c 100644 --- a/frame_averaging_pytorch/frame_averaging.py +++ b/frame_averaging_pytorch/frame_averaging.py @@ -1,10 +1,10 @@ from __future__ import annotations from random import randrange -from optree import tree_map import torch from torch.nn import Module +from torch.utils._pytree import tree_map from einops import rearrange, repeat, reduce, einsum diff --git a/pyproject.toml b/pyproject.toml index f6009cb..e0a7ec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "frame-averaging-pytorch" -version = "0.1.1" +version = "0.1.2" description = "Frame Averaging" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -25,7 +25,6 @@ classifiers=[ dependencies = [ "torch>=2.0", "einops>=0.8.0", - "optree" ] [project.urls]