diff --git a/cmaes/_warm_start.py b/cmaes/_warm_start.py index fb8ab56..a915bdf 100644 --- a/cmaes/_warm_start.py +++ b/cmaes/_warm_start.py @@ -37,7 +37,7 @@ def get_warm_start_mgd( source_solutions = sorted(source_solutions, key=lambda t: t[1]) gamma_n = math.floor(len(source_solutions) * gamma) assert gamma_n >= 1, "One or more solutions must be selected from a source task" - dim = len(source_solutions[0]) + dim = len(source_solutions[0][0]) top_gamma_solutions = np.empty( shape=( gamma_n, diff --git a/tests/test_warm_start.py b/tests/test_warm_start.py new file mode 100644 index 0000000..3f747a3 --- /dev/null +++ b/tests/test_warm_start.py @@ -0,0 +1,12 @@ +import numpy as np +from unittest import TestCase +from cmaes import CMA, get_warm_start_mgd + + +class TestWarmStartCMA(TestCase): + def test_dimension(self): + optimizer = CMA(mean=np.zeros(10), sigma=1.3) + source_solutions = [(optimizer.ask(), 0.0) for _ in range(100)] + ws_mean, ws_sigma, ws_cov = get_warm_start_mgd(source_solutions) + + self.assertEqual(ws_mean.size, 10)