
Commit 160b268

add gradient clipping
zdong1 committed Nov 10, 2023
1 parent 557e3d1 commit 160b268
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/pg/pg_jdp.py
@@ -120,7 +120,8 @@ def compute_gradient_estimate_jdp(
     for trajectory in trajectories:
         # Clip the total reward to adhere to the sensitivity bounds
         total_reward = min(sum(reward_function(state, action) for state, action in trajectory), R_max)
-
+        gradient_per_trajectory = [torch.zeros_like(param) for param in theta]
+        # true gradient
         for state, action in trajectory:
             log_prob = torch.log(policy_function(theta, action, state))

@@ -129,11 +130,17 @@ def compute_gradient_estimate_jdp(
             if param.grad is not None:
                 param.grad.zero_()
             log_prob.backward(retain_graph=True)
-            gradients[i] += param.grad * total_reward
-
+            with torch.no_grad():
+                grad_norm = torch.norm(param.grad)
+                # clipped_grad = param.grad * min(G/ (grad_norm + 1e-6), 1.0)
+                gradient_per_trajectory[i] += param.grad * total_reward / m
+        # add clipping to true gradient
+        gpt_norm = torch.norm(gradient_per_trajectory.grad)
+        clipped_grad = gradient_per_trajectory * min(G/ (gpt_norm + 1e-6), 1.0)
+        param.grad += clipped_grad
     # Adding Gaussian noise for differential privacy
     noisy_gradients = [
-        gradient / m + torch.normal(0, np.sqrt(sigma_squared), size=gradient.shape)
+        gradient + torch.normal(0, np.sqrt(sigma_squared), size=gradient.shape)
         for gradient in gradients
     ]
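Note: as committed, the new clipping block has two bugs. gradient_per_trajectory is a plain Python list of tensors, so gradient_per_trajectory.grad raises an AttributeError (and torch.norm cannot take a list), and the clipped result is added into param.grad (whichever parameter the inner loop last touched) rather than into the gradients accumulator that the noise step reads. Below is a minimal sketch of what the per-trajectory clipping appears to intend, reusing the diff's names (theta, gradients, gradient_per_trajectory, G, m, sigma_squared); the global-L2-norm formulation and the accumulation into gradients are assumptions, not confirmed by the source.

    import numpy as np
    import torch

    # Sketch: runs once per trajectory, after gradient_per_trajectory[i]
    # has accumulated param.grad * total_reward / m for every parameter i.

    # Global L2 norm across all parameter tensors of this trajectory's gradient
    gpt_norm = torch.sqrt(sum(torch.sum(g ** 2) for g in gradient_per_trajectory))
    # Scale down only when the norm exceeds the clipping threshold G
    clip_coef = min(G / (gpt_norm.item() + 1e-6), 1.0)
    for i in range(len(theta)):
        gradients[i] += gradient_per_trajectory[i] * clip_coef

    # After the trajectory loop: add Gaussian noise for differential privacy.
    # The diff drops the old "/ m" here because each per-trajectory gradient
    # was already divided by m during accumulation.
    noisy_gradients = [
        gradient + torch.normal(0.0, float(np.sqrt(sigma_squared)), size=gradient.shape)
        for gradient in gradients
    ]

Bounding each trajectory's gradient to L2 norm at most G caps the sensitivity of the summed estimate, which is presumably what sigma_squared is calibrated against for the joint-DP guarantee suggested by the function name.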
