Skip to content

Commit

Permalink
fix: invalid buffer reset
Browse files Browse the repository at this point in the history
Signed-off-by: Anirudh <anirudh@semiotic.ai>
  • Loading branch information
anirudh2 committed Mar 7, 2023
1 parent 8135e20 commit 4830323
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions algorithmconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def config():
"actiondistribution": {
"kind": "gaussian",
"initial_mean": [1.0],
- "initial_stddev": [0.1],
+ "initial_stddev": [0.5],
"minmean": [0.0],
"maxmean": [2.0],
"minstddev": [0.1],
"maxstddev": [1.0],
},
- "optimizer": {"kind": "sgd", "lr": 0.001},
+ "optimizer": {"kind": "sgd", "lr": 0.01},
"ppoiterations": 2,
"epsclip": 0.1,
"entropycoeff": 1e-1,
Expand Down
13 changes: 8 additions & 5 deletions autoagora_agents/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,9 @@ def _entropyloss(self) -> torch.Tensor:
"""Penalise high entropies."""
return -self.actiondist.entropy() * self.entropycoeff

- def _update(self):
+ def _update(self) -> bool:
if not buffer.isfull(self.buffer):
- return
+ return False
super().update()

rewards = buffer.get("reward", self.buffer)
Expand Down Expand Up @@ -369,9 +369,12 @@ def _update(self):
torch.sum(loss).backward()
self.opt.step()

+ return True

def update(self):
- self._update()
- self.buffer.clear()
+ ran = self._update()
+ if ran:
+ self.buffer.clear()


# NOTE: This is experimental. Do not use!
Expand Down Expand Up @@ -423,7 +426,7 @@ def logprob(self, _):
return buffer.get("logprob", self.buffer).unsqueeze(dim=1)

def update(self):
- self._update()
+ _ = self._update()


def algorithmgroupfactory(*, kind: str, count: int, **kwargs) -> list[Algorithm]:
Expand Down
2 changes: 1 addition & 1 deletion simulationconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@simulation_ingredient.config
def config():
nproducts = 1
- ntimesteps = 100
+ ntimesteps = 1000
nepisodes = 1
distributor = {"kind": "softmax", "source": "consumer", "to": "indexer"}
entities = [
Expand Down

0 comments on commit 4830323

Please sign in to comment.