diff --git a/frame_averaging_pytorch/frame_averaging.py b/frame_averaging_pytorch/frame_averaging.py index 8268433..52bce7c 100644 --- a/frame_averaging_pytorch/frame_averaging.py +++ b/frame_averaging_pytorch/frame_averaging.py @@ -1,10 +1,10 @@ from __future__ import annotations from random import randrange -from optree import tree_map import torch from torch.nn import Module +from torch.utils._pytree import tree_map from einops import rearrange, repeat, reduce, einsum diff --git a/pyproject.toml b/pyproject.toml index f6009cb..e0a7ec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "frame-averaging-pytorch" -version = "0.1.1" +version = "0.1.2" description = "Frame Averaging" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -25,7 +25,6 @@ classifiers=[ dependencies = [ "torch>=2.0", "einops>=0.8.0", - "optree" ] [project.urls]