Commit

add dp pg
zdong1 committed Nov 10, 2023
1 parent 7574a01 commit f8f5bc4
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/pg/pg_dp.py
@@ -0,0 +1,54 @@
from typing import Any, Callable, List, Tuple
import torch

def compute_gradient_estimate(
    theta: torch.Tensor,
    trajectories: List[List[Tuple[Any, Any]]],
    reward_function: Callable[[Any, Any], float],
    policy_function: Callable[[torch.Tensor, Any, Any], torch.Tensor],
) -> torch.Tensor:
    """
    Computes the REINFORCE-style gradient estimate of J(theta) for policy gradient methods:

        grad J(theta) ~= (1/m) * sum_i R(tau_i) * sum_t grad_theta log pi_theta(a_t | s_t)

    where R(tau_i) is the total reward of trajectory tau_i and m is the number of trajectories.

    Parameters:
        theta (torch.Tensor): Parameters of the neural network policy, with gradient tracking enabled.
        trajectories (List[List[Tuple[Any, Any]]]): List of trajectories; each trajectory is a list of (state, action) tuples.
        reward_function (Callable[[Any, Any], float]): Computes the reward for a (state, action) pair.
        policy_function (Callable[[torch.Tensor, Any, Any], torch.Tensor]): Called as policy_function(theta, action, state); returns the probability of the action given the state under the policy, as a scalar tensor connected to theta.

    Returns:
        torch.Tensor: The gradient estimate of J(theta), with the same shape as theta.

    Raises:
        ValueError: If the input parameters are not in the expected format or type.
    """

    # Validate inputs
    if not isinstance(theta, torch.Tensor):
        raise ValueError("Theta must be a PyTorch Tensor.")
    if not theta.requires_grad:
        raise ValueError("Theta must require gradient.")
    if not isinstance(trajectories, list) or not all(isinstance(traj, list) for traj in trajectories):
        raise ValueError("Trajectories must be a list of lists of (state, action) tuples.")
    if not trajectories:
        raise ValueError("Trajectories must contain at least one trajectory.")
    if not callable(reward_function) or not callable(policy_function):
        raise ValueError("Reward and policy functions must be callable.")

    m = len(trajectories)  # Number of trajectories
    gradient_sum = torch.zeros_like(theta)  # Initialize the gradient sum

    for trajectory in trajectories:
        # Total (undiscounted) reward of the trajectory, R(tau)
        total_reward = sum(reward_function(state, action) for state, action in trajectory)

        # Sum of grad_theta log pi_theta(a_t | s_t) over the trajectory
        gradient = torch.zeros_like(theta)
        for state, action in trajectory:
            log_prob = policy_function(theta, action, state).log()
            # Compute the gradient of the log probability w.r.t. theta
            gradient += torch.autograd.grad(log_prob, theta, retain_graph=True)[0]

        gradient_sum += total_reward * gradient  # Accumulate the reward-weighted gradient

    # Average the weighted gradients over all trajectories
    average_gradient = gradient_sum / m

    return average_gradient
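

# ---------------------------------------------------------------------------
# Usage sketch (not part of the committed module): the tabular softmax policy
# and toy reward below are hypothetical, chosen only to illustrate the expected
# call shape of compute_gradient_estimate.
if __name__ == "__main__":
    torch.manual_seed(0)

    n_states, n_actions = 4, 3
    theta = torch.randn(n_states, n_actions, requires_grad=True)

    def policy_function(params: torch.Tensor, action: int, state: int) -> torch.Tensor:
        # Tabular softmax policy: pi(action | state) from a table of logits.
        probs = torch.softmax(params[state], dim=0)
        return probs[action]

    def reward_function(state: int, action: int) -> float:
        # Toy reward: +1 when the action index matches the state index mod n_actions.
        return 1.0 if state % n_actions == action else 0.0

    # Two short trajectories of (state, action) pairs.
    trajectories = [
        [(0, 0), (1, 1), (2, 2)],
        [(3, 0), (0, 1)],
    ]

    grad_estimate = compute_gradient_estimate(
        theta, trajectories, reward_function, policy_function
    )
    print(grad_estimate.shape)  # torch.Size([4, 3]), same shape as theta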
