prioritized_replay_buffer.py
import numpy as np


class PrioritizedReplayBuffer:
    """
    Prioritized experience replay buffer for reinforcement learning.

    Transitions are stored with a priority and sampled in proportion to
    priority ** alpha, which focuses learning on important transitions.
    Sampling also returns importance-sampling weights (controlled by beta)
    to correct the resulting bias, and memory usage is bounded by the
    configurable capacity.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity  # maximum number of stored transitions
        self.alpha = alpha        # how strongly priorities shape sampling
        self.beta = beta          # strength of importance-sampling correction
        self.buffer = []
        self.priorities = []
        self.pos = 0              # next write position once the buffer is full

    def append(self, experience):
        # New transitions receive the current maximum priority so they are
        # sampled at least once before their priority is updated.
        max_priority = max(self.priorities) if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.priorities.append(max_priority)
        else:
            # Overwrite the oldest entry (circular-buffer behaviour).
            self.buffer[self.pos] = experience
            self.priorities[self.pos] = max_priority
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        # Sampling probability P(i) = p_i ** alpha / sum_k p_k ** alpha.
        priorities = np.array(self.priorities)
        probs = priorities ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        # Importance-sampling weights w_i = (N * P(i)) ** -beta,
        # normalised by the maximum weight for stability.
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return samples, indices, weights

    def update_priorities(self, indices, priorities):
        # Typically called with the absolute TD errors of the sampled batch.
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)
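

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the buffer itself):
# transitions are assumed to be (state, action, reward, next_state, done)
# tuples, and the new priorities come from hypothetical absolute TD errors
# of the sampled batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    buffer = PrioritizedReplayBuffer(capacity=1000)
    for _ in range(100):
        state = rng.normal(size=4)
        action = int(rng.integers(2))
        reward = float(rng.normal())
        next_state = rng.normal(size=4)
        done = bool(rng.integers(2))
        buffer.append((state, action, reward, next_state, done))

    samples, indices, weights = buffer.sample(batch_size=32)

    # After computing TD errors for the batch, feed their magnitudes back as
    # the new priorities (a small epsilon keeps them strictly positive).
    fake_td_errors = rng.normal(size=len(indices))
    buffer.update_priorities(indices, np.abs(fake_td_errors) + 1e-6)

    print(len(buffer), weights.shape)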