Skip to content

Commit

Permalink
fixed result struct
Browse files Browse the repository at this point in the history
  • Loading branch information
Aizen committed Apr 22, 2024
1 parent a5a4130 commit 7c8c0a7
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 29 deletions.
6 changes: 3 additions & 3 deletions _examples/other.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func main() {
mu := []float64{1.0, 2.0} // Initial mean vector
sigma := []float64{0.5, 0.5} // Initial standard deviation vector

optimizedMu, optimizedSigma := GoES.DefaultOpt(myCustomFunction, mu, sigma)
res, _ := GoES.DefaultOpt(myCustomFunction, mu, sigma)

fmt.Println("Optimized mean:", optimizedMu)
fmt.Println("Optimized standard deviation:", optimizedSigma)
fmt.Println("Optimized mean:", res.Mu)
fmt.Println("Optimized standard deviation:", res.Sigma)
}
4 changes: 2 additions & 2 deletions _examples/sphere.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func main() {
cfg.Verbose = false

// Perform optimization
optimizedMu, _ := GoES.Opt(sphere, mu, sigma, cfg)
res, _ := GoES.Opt(sphere, mu, sigma, cfg)

fmt.Println("Optimum:", optimizedMu) // should be close to vector [0, 1, 2, ..., dim-1]
fmt.Println("Optimum:", res.Mu) // should be close to vector [0, 1, 2, ..., dim-1]
}
32 changes: 27 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ This code implements a specific algorithm called CMA-ES (Covariance Matrix Adapt
Overall, this code provides an implementation of CMA-ES for optimizing a black-box function in Go. It allows you to specify your own objective function and configure various parameters for the optimization process.
*/
package goes
package GoES

import (
"fmt"
"log"
"math"
"slices"
Expand Down Expand Up @@ -69,12 +70,17 @@ func Defaults() Config {
return cfg
}

type Result struct {
Mu []float64
Sigma []float64
}

const const_Ez0 = 0.7978845608028661 // mean(abs(randn()))
func Opt(fn func([]float64) float64, mu []float64, sigma []float64, cfg Config) ([]float64, []float64) {
func Opt(fn func([]float64) float64, mu []float64, sigma []float64, cfg Config) (Result, error) {
pop_n := cfg.PopSize
n := len(mu)
if len(sigma) != n {
log.Panic("mu and sigma must have the same length.")
return Result{}, fmt.Errorf("mu (len %d) and sigma (len %d) must have the same length", len(mu), len(sigma))
}
for pop_n*pop_n <= 144*n {
pop_n++
Expand Down Expand Up @@ -149,10 +155,10 @@ func Opt(fn func([]float64) float64, mu []float64, sigma []float64, cfg Config)
log.Println("GoES: ", runs, mu, sigma, pop[pop_n/2].C)
}
}
return mu, sigma
return Result{Mu: mu, Sigma: sigma}, nil
}

func DefaultOpt(fn func([]float64) float64, mu []float64, sigma []float64) ([]float64, []float64) {
func DefaultOpt(fn func([]float64) float64, mu []float64, sigma []float64) (Result, error) {
cfg := Defaults()
cfg.Generations = int(math.Ceil(math.Sqrt(float64(len(mu)*2+1)) * 300))
return Opt(fn, mu, sigma, cfg)
Expand All @@ -174,3 +180,19 @@ func makeWeights(pop_size int) []float64 {
}
return W
}

func Positive(z float64) float64 {
if z < 0 {
return 1 / (1 - z)
}
return z + 1
}

func Probability(z float64) float64 {
p := Positive(z)
return p / (1 + p)
}

func Bounded(x, a, b float64) float64 {
return Probability(x)*(b-a) + a
}
38 changes: 25 additions & 13 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package goes
package GoES

import (
"bytes"
Expand All @@ -22,8 +22,11 @@ func cost_test(ince float64, iva_detratta float64) float64 {
pen := impo*perc_pay - ince_l
return abs2(pen)
}
mu, _ := DefaultOpt(cost, []float64{2 * ince}, []float64{ince / 10})
return mu[0]
sol, err := DefaultOpt(cost, []float64{ince * 0.9}, []float64{ince / 10})
if err != nil {
return math.NaN()
}
return sol.Mu[0]
}

func TestUni(t *testing.T) {
Expand All @@ -42,9 +45,14 @@ func TestUni(t *testing.T) {

func TestBi(t *testing.T) {
muw := []float64{4, -3}
mu, sig := DefaultOpt(func(f []float64) float64 {
sol, err_opt := DefaultOpt(func(f []float64) float64 {
return abs2(f[0]-muw[0]) + 100.0*abs2(f[0]+f[1]-muw[0]-muw[1])
}, []float64{0.0, 0.0}, []float64{1.0, 1.0})
if err_opt != nil {
t.Error(err_opt)
}
mu := sol.Mu
sig := sol.Sigma
err := math.Sqrt(abs2((mu[0]-muw[0])/muw[0]) + abs2((mu[1]-muw[1])/muw[1]))
if err > 1e-6 {
t.Error("got: ", mu, sig, " wanted:", muw, " error:", err)
Expand All @@ -66,17 +74,21 @@ func TestVerbose(t *testing.T) {
}
}

func TestPanic(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()

// The following is the code under test for panicking
DefaultOpt(
func TestError(t *testing.T) {
_, err := DefaultOpt(
func(f []float64) float64 { return 0.0 },
[]float64{0.0, 0.0},
[]float64{1.0}, // here there is a missing element
)
if err == nil {
t.Error("Expected error")
}
_, err = DefaultOpt(
func(f []float64) float64 { return 0.0 },
[]float64{0.0, 0.0},
[]float64{1.0, 1.0}, // here there is a missing element
)
if err != nil {
t.Error("Expected success")
}
}
11 changes: 5 additions & 6 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ func main() {
cfg.Verbose = false

// Perform optimization
optimizedMu, _ := GoES.Opt(sphere, mu, sigma, cfg)
res, _ := GoES.Opt(sphere, mu, sigma, cfg)

fmt.Println("Optimum:", optimizedMu) // should be close to vector [0, 1, 2, ..., dim-1]
fmt.Println("Optimum:", res.Mu) // should be close to vector [0, 1, 2, ..., dim-1]
}

```

**Example 2: Default Optimization Config**
Expand Down Expand Up @@ -98,9 +97,9 @@ func main() {
mu := []float64{1.0, 2.0} // Initial mean vector
sigma := []float64{0.5, 0.5} // Initial standard deviation vector

optimizedMu, optimizedSigma := GoES.DefaultOpt(myCustomFunction, mu, sigma)
res, _ := GoES.DefaultOpt(myCustomFunction, mu, sigma)

fmt.Println("Optimized mean:", optimizedMu)
fmt.Println("Optimized standard deviation:", optimizedSigma)
fmt.Println("Optimized mean:", res.Mu)
fmt.Println("Optimized standard deviation:", res.Sigma)
}
```

0 comments on commit 7c8c0a7

Please sign in to comment.