Skip to content

Commit

Permalink
buffer base class; sample_from
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan committed Dec 4, 2024
1 parent 9b18617 commit c0091e0
Showing 1 changed file with 66 additions and 14 deletions.
80 changes: 66 additions & 14 deletions ldp/alg/optimizer/replay_buffers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import asyncio
import logging
import random
import typing
from collections import UserList
from collections.abc import Awaitable, Callable, Iterator
from enum import StrEnum, auto
from itertools import chain
from typing import Any, cast

import numpy as np
import torch
from litellm import cast
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm

from ldp.graph import eval_mode
Expand All @@ -15,13 +21,8 @@
logger = logging.getLogger(__name__)


class CircularReplayBuffer(UserList[dict]):
def resize(self, size: int):
if len(self) > size:
self.data = self.data[-size:]

async def prepare_for_sampling(self):
"""Optional method for the buffer to prepare itself before sampling."""
class ReplayBuffer(UserList[dict]):
"""A base replay buffer that only allows adding and sampling."""

@staticmethod
def _batched_iter(
Expand Down Expand Up @@ -61,14 +62,39 @@ def batched_iter(
) -> Iterator[dict]:
return self._batched_iter(self.data, batch_size, shuffle, infinite)

def resize(self, size: int):
"""Optional method for the buffer to resize itself."""

async def prepare_for_sampling(self):
"""Optional method for the buffer to prepare itself before sampling."""

@staticmethod
def sample_from(*buffers: ReplayBuffer, **kwargs) -> Iterator[dict]:
"""Helper method to uniformly sample from multiple buffers."""
if any(isinstance(b, PrioritizedReplayBuffer) for b in buffers):
# This is because PrioritizedReplayBuffer determines samples inside
# batched_iter, so we cannot rely on buffer.data.
raise RuntimeError(
"sample_from does not support prioritized replay buffers"
)

all_buffers = list(chain.from_iterable(b.data for b in buffers))
return ReplayBuffer._batched_iter(data=all_buffers, **kwargs)


class CircularReplayBuffer(ReplayBuffer):
def resize(self, size: int):
if len(self) > size:
self.data = self.data[-size:]


class RandomizedReplayBuffer(CircularReplayBuffer):
class RandomizedReplayBuffer(ReplayBuffer):
def resize(self, size: int):
if len(self) > size:
self.data = random.sample(self.data, size)


class PrioritizedReplayBuffer(CircularReplayBuffer):
class PrioritizedReplayBuffer(ReplayBuffer):
"""Implements a variant of https://arxiv.org/abs/1511.05952.
Instead of updating the TD error on the fly, we compute it for all samples
Expand Down Expand Up @@ -164,8 +190,34 @@ def batched_iter(
)
buffer = [self.data[i] for i in sampled_idcs]

# DEBUG
sampled_priorities = prio[sampled_idcs]
logger.debug(f"Average priority: {sampled_priorities.mean()}")

return self._batched_iter(buffer, batch_size, shuffle, infinite)


class ReplayBufferType(StrEnum):
ADD_ONLY = auto()
CIRCULAR = auto()
RANDOMIZED = auto()
PRIORITIZED = auto()


class ReplayBufferConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

buf_type: ReplayBufferType = Field(
default=ReplayBufferType.CIRCULAR, description="Circular is a good default."
)
size: int | None = None
kwargs: dict[str, Any] = Field(default_factory=dict)

def make_buffer(self, **kwargs) -> ReplayBuffer:
# kwargs are only for prioritized case
if self.buf_type == ReplayBufferType.ADD_ONLY:
return ReplayBuffer()
if self.buf_type == ReplayBufferType.CIRCULAR:
return CircularReplayBuffer()
if self.buf_type == ReplayBufferType.RANDOMIZED:
return RandomizedReplayBuffer()
if self.buf_type == ReplayBufferType.PRIORITIZED:
kwargs = self.kwargs | kwargs
return PrioritizedReplayBuffer(**kwargs)
typing.assert_never(self.buf_type)

0 comments on commit c0091e0

Please sign in to comment.