From cc0aa38df1696978370e26ada8effc96f1453c07 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 28 May 2024 16:55:25 +0800 Subject: [PATCH] fix(pt): build nlist faster with `torch.amax` (#3826) ## Summary by CodeRabbit - **Performance Improvements** - Enhanced the performance of coordinate extension in neighbor lists by optimizing internal computations. --- deepmd/pt/utils/nlist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index cdee6e3722..a24a5aef72 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -326,7 +326,7 @@ def extend_coord_with_ghosts( # +1: central cell nbuff = torch.ceil(rcut / to_face).to(torch.long) # 3 - nbuff = torch.max(nbuff, dim=0, keepdim=False).values + nbuff = torch.amax(nbuff, dim=0) # faster than torch.max nbuff_cpu = nbuff.cpu() xi = torch.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu") yi = torch.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, device="cpu")