pt: Add parallel implementation for LKF #3436

Merged (11 commits) on Mar 11, 2024
deepmd/pt/optimizer/LKF.py: 176 changes (140 additions, 36 deletions)
@@ -3,11 +3,25 @@
 import math

 import torch
+import torch.distributed as dist
 from torch.optim.optimizer import (
     Optimizer,
 )

 log = logging.getLogger(__name__)


+def distribute_indices(total_length, num_workers):
+    indices_per_worker = total_length // num_workers
+    remainder = total_length % num_workers
+
+    indices = []
+    start = 0
+
+    for i in range(num_workers):
+        end = start + indices_per_worker + (1 if i < remainder else 0)
+        indices.append((start, end))
+        start = end
+
+    return indices, remainder
+
+
 class LKFOptimizer(Optimizer):
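The distribute_indices helper added above hands out contiguous, near-equal ranges of block indices, with the first `remainder` workers taking one extra block each. A quick sketch of its behavior under hypothetical sizes (10 blocks over 4 workers; not taken from the PR):

# Hypothetical sizes: 10 Kalman blocks spread over 4 workers.
# 10 % 4 = 2, so the first two workers each get one extra block.
indices, remainder = distribute_indices(10, 4)
print(indices)    # [(0, 3), (3, 6), (6, 8), (8, 10)]
print(remainder)  # 2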
@@ -18,12 +32,13 @@
         kalman_nue=0.9987,
         block_size=5120,
     ):
-        defaults = {
-            "lr": 0.1,
-            "kalman_nue": kalman_nue,
-            "block_size": block_size,
-        }
-        super().__init__(params, defaults)
+        defaults = dict(
+            lr=0.1,
+            kalman_nue=kalman_nue,
+            block_size=block_size,
+        )
+
+        super(LKFOptimizer, self).__init__(params, defaults)

         self._params = self.param_groups[0]["params"]
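For context, a hedged usage sketch of the constructor as it stands after this change; the model and values here are hypothetical, not from the PR:

import torch

# Assumes the full LKF.py is importable; kalman_lambda is the one
# remaining __init__ argument (stored into optimizer state below).
model = torch.nn.Linear(128, 1)
opt = LKFOptimizer(model.parameters(), kalman_nue=0.9987, block_size=5120)
# If torch.distributed.init_process_group() was called beforehand,
# dist.is_initialized() is True and the P blocks get sharded across ranks;
# otherwise every block stays on the local device.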

@@ -36,7 +51,10 @@
         # the first param, because this helps with casting in load_state_dict
         self._state = self.state[self._params[0]]
         self._state.setdefault("kalman_lambda", kalman_lambda)
-
+        self.dist_init = dist.is_initialized()
+        self.rank = dist.get_rank() if self.dist_init else 0
+        self.dindex = []
+        self.remainder = 0
         self.__init_P()

     def __init_P(self):
@@ -61,32 +79,85 @@

         P = []
         params_packed_index = []
-        log.info("LKF parameter nums: %s" % param_nums)
-        for param_num in param_nums:
-            if param_num >= block_size:
-                block_num = math.ceil(param_num / block_size)
-                for i in range(block_num):
-                    if i != block_num - 1:
-                        P.append(
-                            torch.eye(
-                                block_size,
-                                dtype=data_type,
-                                device=device,
-                            )
-                        )
-                        params_packed_index.append(block_size)
-                    else:
-                        P.append(
-                            torch.eye(
-                                param_num - block_size * i,
-                                dtype=data_type,
-                                device=device,
-                            )
-                        )
-                        params_packed_index.append(param_num - block_size * i)
-            else:
-                P.append(torch.eye(param_num, dtype=data_type, device=device))
-                params_packed_index.append(param_num)
+        logging.info("LKF parameter nums: %s" % param_nums)
+        if self.dist_init:
+            block_num = 0
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num += math.ceil(param_num / block_size)
+                else:
+                    block_num += 1
+            num_workers = dist.get_world_size()
+            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
+            index = 0
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num = math.ceil(param_num / block_size)
+                    for i in range(block_num):
+                        device_id = self.get_device_id(index)
+                        index += 1
+                        dist_device = torch.device("cuda:" + str(device_id))
+                        if i != block_num - 1:
+                            params_packed_index.append(block_size)
+                            if self.rank == device_id:
+                                P.append(
+                                    torch.eye(
+                                        block_size,
+                                        dtype=data_type,
+                                        device=dist_device,
+                                    )
+                                )
+                            else:
+                                continue
+                        else:
+                            params_packed_index.append(param_num - block_size * i)
+                            if self.rank == device_id:
+                                P.append(
+                                    torch.eye(
+                                        param_num - block_size * i,
+                                        dtype=data_type,
+                                        device=dist_device,
+                                    )
+                                )
+                            else:
+                                continue
+                else:
+                    device_id = self.get_device_id(index)
+                    index += 1
+                    params_packed_index.append(param_num)
+                    if self.rank == device_id:
+                        dist_device = torch.device("cuda:" + str(device_id))
+                        P.append(
+                            torch.eye(param_num, dtype=data_type, device=dist_device)
+                        )
+                        device_id = self.rank
+        else:
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num = math.ceil(param_num / block_size)
+                    for i in range(block_num):
+                        if i != block_num - 1:
+                            P.append(
+                                torch.eye(
+                                    block_size,
+                                    dtype=data_type,
+                                    device=device,
+                                )
+                            )
+                            params_packed_index.append(block_size)
+                        else:
+                            P.append(
+                                torch.eye(
+                                    param_num - block_size * i,
+                                    dtype=data_type,
+                                    device=device,
+                                )
+                            )
+                            params_packed_index.append(param_num - block_size * i)
+                else:
+                    P.append(torch.eye(param_num, dtype=data_type, device=device))
+                    params_packed_index.append(param_num)

         self._state.setdefault("P", P)
         self._state.setdefault("weights_num", len(P))
@@ -115,6 +186,8 @@

     def __update(self, H, error, weights):
         P = self._state.get("P")
+        # for item in P:
+        #     print(self.rank," size ",item.shape)
         kalman_lambda = self._state.get("kalman_lambda")
         weights_num = self._state.get("weights_num")
         params_packed_index = self._state.get("params_packed_index")
@@ -125,16 +198,35 @@
         tmp = 0
         for i in range(weights_num):
             tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i]))

+        if self.dist_init:
+            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)
         A = 1 / tmp

         for i in range(weights_num):
             K = torch.matmul(P[i], H[i])

             weights[i] = weights[i] + A * error * K

             P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))

+        if self.dist_init:
+            device = torch.device("cuda:" + str(self.rank))
+            local_shape = [tensor.shape[0] for tensor in weights]
+            shape_list = [
+                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
+                for _ in range(dist.get_world_size())
+            ]
+            dist.all_gather_object(shape_list, local_shape)
+            weight_tensor = torch.cat(weights)
+            world_shape = [sum(inner_list) for inner_list in shape_list]
+            weight_list = [None] * len(world_shape)
+            for i in range(len(world_shape)):
+                weight_list[i] = torch.zeros(
+                    world_shape[i], dtype=torch.float64, device=device
+                )
+            dist.all_gather(weight_list, weight_tensor)
+            result = []
+            for i in range(dist.get_world_size()):
+                result = result + list(torch.split(weight_list[i], shape_list[i]))
+            weights = result
         kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
         self._state.update({"kalman_lambda": kalman_lambda})

@@ -215,9 +307,21 @@
                 param_sum += nelement

             if param_sum == params_packed_index[param_index]:
-                H.append(res_grad)
-                weights.append(res)
                 param_sum = 0
+                if self.dist_init:
+                    device_id = self.get_device_id(param_index)
+                    if self.rank == device_id:
+                        weights.append(res)
+                        H.append(res_grad)
+                else:
+                    weights.append(res)
+                    H.append(res_grad)
                 param_index += 1

         self.__update(H, error, weights)
+
+    def get_device_id(self, index):
+        for i, (start, end) in enumerate(self.dindex):
+            if start <= index < end:
+                return i
+        return None
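get_device_id is a linear scan over the (start, end) ranges produced by distribute_indices; indices outside every range return None. A self-contained sketch using the hypothetical 10-block, 4-worker split from earlier:

dindex = [(0, 3), (3, 6), (6, 8), (8, 10)]

def owner(index):
    # Same lookup as LKFOptimizer.get_device_id.
    for rank, (start, end) in enumerate(dindex):
        if start <= index < end:
            return rank
    return None

assert owner(0) == 0 and owner(4) == 1 and owner(9) == 3
assert owner(10) is None  # out-of-range indices fall through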
