diff --git a/ldp/alg/optimizer/replay_buffers.py b/ldp/alg/optimizer/replay_buffers.py
index f6c6839..6c48070 100644
--- a/ldp/alg/optimizer/replay_buffers.py
+++ b/ldp/alg/optimizer/replay_buffers.py
@@ -1,12 +1,18 @@
+from __future__ import annotations
+
 import asyncio
 import logging
 import random
+import typing
 from collections import UserList
 from collections.abc import Awaitable, Callable, Iterator
+from enum import StrEnum, auto
+from itertools import chain
+from typing import Any, cast
 
 import numpy as np
 import torch
-from litellm import cast
+from pydantic import BaseModel, ConfigDict, Field
 from tqdm import tqdm
 
 from ldp.graph import eval_mode
@@ -15,13 +21,8 @@
 logger = logging.getLogger(__name__)
 
 
-class CircularReplayBuffer(UserList[dict]):
-    def resize(self, size: int):
-        if len(self) > size:
-            self.data = self.data[-size:]
-
-    async def prepare_for_sampling(self):
-        """Optional method for the buffer to prepare itself before sampling."""
+class ReplayBuffer(UserList[dict]):
+    """A base replay buffer that only allows adding and sampling."""
 
     @staticmethod
     def _batched_iter(
@@ -61,14 +62,39 @@ def batched_iter(
     ) -> Iterator[dict]:
         return self._batched_iter(self.data, batch_size, shuffle, infinite)
 
+    def resize(self, size: int) -> None:
+        """Optional method for the buffer to resize itself."""
+
+    async def prepare_for_sampling(self) -> None:
+        """Optional method for the buffer to prepare itself before sampling."""
+
+    @staticmethod
+    def sample_from(*buffers: ReplayBuffer, **kwargs) -> Iterator[dict]:
+        """Helper method to uniformly sample from multiple buffers."""
+        if any(isinstance(b, PrioritizedReplayBuffer) for b in buffers):
+            # PrioritizedReplayBuffer selects its samples inside batched_iter,
+            # so we cannot rely on buffer.data here.
+            raise RuntimeError(
+                "sample_from does not support prioritized replay buffers"
+            )
+
+        all_samples = list(chain.from_iterable(b.data for b in buffers))
+        return ReplayBuffer._batched_iter(data=all_samples, **kwargs)
+
 
-class RandomizedReplayBuffer(CircularReplayBuffer):
+class CircularReplayBuffer(ReplayBuffer):
+    def resize(self, size: int):
+        if len(self) > size:
+            self.data = self.data[-size:]
+
+
+class RandomizedReplayBuffer(ReplayBuffer):
     def resize(self, size: int):
         if len(self) > size:
             self.data = random.sample(self.data, size)
 
 
-class PrioritizedReplayBuffer(CircularReplayBuffer):
+class PrioritizedReplayBuffer(ReplayBuffer):
     """Implements a variant of https://arxiv.org/abs/1511.05952.
 
     Instead of updating the TD error on the fly, we compute it for all samples
@@ -164,8 +190,35 @@ def batched_iter(
         )
 
         buffer = [self.data[i] for i in sampled_idcs]
-        # DEBUG
-        sampled_priorities = prio[sampled_idcs]
-        logger.debug(f"Average priority: {sampled_priorities.mean()}")
-
         return self._batched_iter(buffer, batch_size, shuffle, infinite)
+
+
+class ReplayBufferType(StrEnum):
+    # Each member maps to a ReplayBuffer subclass below
+    APPEND_ONLY = auto()
+    CIRCULAR = auto()
+    RANDOMIZED = auto()
+    PRIORITIZED = auto()
+
+
+class ReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
+    buf_type: ReplayBufferType = Field(
+        default=ReplayBufferType.CIRCULAR, description="Circular is a good default."
+    )
+    size: int | None = None
+    kwargs: dict[str, Any] = Field(default_factory=dict)
+
+    def make_buffer(self, **kwargs) -> ReplayBuffer:
+        # kwargs are only used in the prioritized case
+        if self.buf_type == ReplayBufferType.APPEND_ONLY:
+            return ReplayBuffer()
+        if self.buf_type == ReplayBufferType.CIRCULAR:
+            return CircularReplayBuffer()
+        if self.buf_type == ReplayBufferType.RANDOMIZED:
+            return RandomizedReplayBuffer()
+        if self.buf_type == ReplayBufferType.PRIORITIZED:
+            kwargs = self.kwargs | kwargs
+            return PrioritizedReplayBuffer(**kwargs)
+        typing.assert_never(self.buf_type)
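Two usage sketches for reviewers follow. First, uniform sampling across multiple buffers via the new `ReplayBuffer.sample_from` helper. This is a minimal sketch: the sample dicts are made up, and the `batch_size`/`shuffle`/`infinite` keywords are assumed to match `_batched_iter`'s signature, since `sample_from` only forwards them via `**kwargs`.

```python
from ldp.alg.optimizer.replay_buffers import CircularReplayBuffer, ReplayBuffer

good, bad = CircularReplayBuffer(), CircularReplayBuffer()
good.extend({"source": "good", "i": i} for i in range(8))  # hypothetical samples
bad.extend({"source": "bad", "i": i} for i in range(8))

# sample_from pools the buffers' data and defers batching to _batched_iter.
# Passing a PrioritizedReplayBuffer raises RuntimeError, because that class
# only decides its sample distribution inside batched_iter.
for batch in ReplayBuffer.sample_from(
    good, bad, batch_size=4, shuffle=True, infinite=False
):
    ...  # each item is a dict produced by _batched_iter's collation
```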
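Second, the config-driven factory. Field values here are hypothetical, and note that `make_buffer()` does not apply `size`, so this sketch assumes the caller resizes explicitly.

```python
from ldp.alg.optimizer.replay_buffers import ReplayBufferConfig, ReplayBufferType

# extra="forbid" means unknown config fields raise a ValidationError.
config = ReplayBufferConfig(buf_type=ReplayBufferType.RANDOMIZED, size=1_000)
buffer = config.make_buffer()
buffer.extend({"state": i, "reward": 0.0} for i in range(2_000))  # made-up data

# RandomizedReplayBuffer.resize keeps a uniform random subset of the data.
if config.size is not None:
    buffer.resize(config.size)
assert len(buffer) == 1_000
```

Since `ReplayBufferType` is a `StrEnum` built with `auto()`, pydantic also accepts plain strings like `buf_type="randomized"`. The trailing `typing.assert_never(self.buf_type)` is meant to make the dispatch exhaustive: a type checker should flag a new enum member that lacks a branch, and at runtime it raises if one slips through.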