Skip to content

Commit

Permalink
docs: Update algorithm docs
Browse files Browse the repository at this point in the history
Signed-off-by: Anirudh <anirudh@semiotic.ai>
  • Loading branch information
anirudh2 committed Mar 7, 2023
1 parent 67e02dc commit d06dfae
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions autoagora_agents/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@


class Algorithm(ABC):
"""Base class for algorithms.
Concrete subclasses must implement :meth:`__call__`.
Attributes:
niterations (int): Number of times the algorithm has been called since the last reset.
nupdates (int): Number of times the algorithm has been updated.
group (str): The group to which the algorithm belongs.
i (int): The index of the algorithm.
name (str): The group and index of the algorithm.
"""

def __init__(self, *, group: str, i: int) -> None:
self.niterations = 0
self.nupdates = 0
Expand All @@ -21,9 +33,11 @@ def __init__(self, *, group: str, i: int) -> None:
self.name = f"{group}_{i}"

def reset(self) -> None:
"""Reset the algorithm's state."""
self.niterations = 0

def update(self):
def update(self) -> None:
"""Update the algorithm's parameters."""
self.nupdates += 1

@abstractmethod
Expand All @@ -35,6 +49,17 @@ def __call__(
reward: float,
done: bool,
) -> np.ndarray:
"""Run the algorithm forward.
Keyword Arguments:
observation (np.ndarray): The observation seen by the agent.
action (np.ndarray): The previous action taken by the agent.
reward (float): The reward of the agent.
done (bool): If True, the agent is no longer in the game.
Returns:
np.ndarray: The next action taken by the agent.
"""
pass

@staticmethod
Expand Down Expand Up @@ -153,7 +178,15 @@ def __call__(
self.niterations += 1
return act

def logprob(self, actions):
def logprob(self, actions: torch.Tensor) -> torch.Tensor:
"""Compute the log probability of the action given the distribution.
Arguments:
actions (torch.Tensor): The actions for which to compute the log probability.
Returns:
torch.Tensor: The log probability of the actions.
"""
return self.actiondist.logprob(actions)


Expand Down

0 comments on commit d06dfae

Please sign in to comment.