Skip to content

Commit

Permalink
Merge pull request #98 from CyberAgent/fix-ws-cma-es-dim
Browse files Browse the repository at this point in the history
Fix dimensions of Warm starting CMA-ES
  • Loading branch information
c-bata authored Feb 19, 2021
2 parents 0da22f5 + 66cbb7c commit 05ef85c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cmaes/_warm_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_warm_start_mgd(
source_solutions = sorted(source_solutions, key=lambda t: t[1])
gamma_n = math.floor(len(source_solutions) * gamma)
assert gamma_n >= 1, "One or more solutions must be selected from a source task"
dim = len(source_solutions[0])
dim = len(source_solutions[0][0])
top_gamma_solutions = np.empty(
shape=(
gamma_n,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_warm_start.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np
from unittest import TestCase
from cmaes import CMA, get_warm_start_mgd


class TestWarmStartCMA(TestCase):
def test_dimension(self):
optimizer = CMA(mean=np.zeros(10), sigma=1.3)
source_solutions = [(optimizer.ask(), 0.0) for _ in range(100)]
ws_mean, ws_sigma, ws_cov = get_warm_start_mgd(source_solutions)

self.assertEqual(ws_mean.size, 10)

0 comments on commit 05ef85c

Please sign in to comment.