diff --git a/autoagora_agents/algorithm.py b/autoagora_agents/algorithm.py
index 6b59f57..35ad39f 100644
--- a/autoagora_agents/algorithm.py
+++ b/autoagora_agents/algorithm.py
@@ -13,6 +13,18 @@
 
 
 class Algorithm(ABC):
+    """Base class for algorithms.
+
+    Concretions must implement :meth:`__call__`.
+
+    Attributes:
+        niterations (int): Number of times the algorithm has been called.
+        nupdates (int): Number of times the algorithm has been updated.
+        group (str): The group to which the algorithm belongs.
+        i (int): The index of the algorithm.
+        name (str): The group and index of the algorithm.
+    """
+
     def __init__(self, *, group: str, i: int) -> None:
         self.niterations = 0
         self.nupdates = 0
@@ -21,9 +33,11 @@ def __init__(self, *, group: str, i: int) -> None:
         self.name = f"{group}_{i}"
 
     def reset(self) -> None:
+        """Reset the algorithm's state."""
         self.niterations = 0
 
-    def update(self):
+    def update(self) -> None:
+        """Update the algorithm's parameters."""
         self.nupdates += 1
 
     @abstractmethod
@@ -35,6 +49,17 @@ def __call__(
         self,
         *,
         observation: np.ndarray,
         action: np.ndarray,
         reward: float,
         done: bool,
     ) -> np.ndarray:
+        """Run the algorithm forward.
+
+        Keyword Arguments:
+            observation (np.ndarray): The observation seen by the agent.
+            action (np.ndarray): The previous action taken by the agent.
+            reward (float): The reward of the agent.
+            done (bool): If True, the agent is no longer in the game.
+
+        Returns:
+            np.ndarray: The next action taken by the agent.
+        """
         pass
 
     @staticmethod
@@ -153,7 +178,15 @@ def __call__(
         self.niterations += 1
         return act
 
-    def logprob(self, actions):
+    def logprob(self, actions: torch.Tensor) -> torch.Tensor:
+        """Compute the log probability of the action given the distribution.
+
+        Arguments:
+            actions (torch.Tensor): The actions for which to compute the log probability.
+
+        Returns:
+            torch.Tensor: The log probability of the actions.
+        """
         return self.actiondist.logprob(actions)