-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathpolicy.go
88 lines (72 loc) · 2.33 KB
/
policy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package her
import (
agentv1 "github.com/aunum/gold/pkg/v1/agent"
"github.com/aunum/gold/pkg/v1/dense"
envv1 "github.com/aunum/gold/pkg/v1/env"
"github.com/aunum/goro/pkg/v1/layer"
modelv1 "github.com/aunum/goro/pkg/v1/model"
"github.com/aunum/log"
g "gorgonia.org/gorgonia"
)
// PolicyConfig holds the hyperparameters MakePolicy uses to build and compile
// a policy model.
type PolicyConfig struct {
	// Loss is the loss function the model is compiled with.
	Loss modelv1.Loss

	// Optimizer is the solver the model is compiled with.
	Optimizer g.Solver

	// LayerBuilder produces the layer configs that are added to the model.
	LayerBuilder LayerBuilder

	// BatchSize is the batch size the model is compiled with.
	BatchSize int

	// Track selects whether the model is compiled with a tracker (true) or
	// with tracking disabled (false).
	Track bool
}
// DefaultPolicyConfig is the stock set of policy hyperparameters: MSE loss, an
// Adam solver configured with batch size 128 and learn rate 0.0005, the default
// fully connected layer builder, a batch size of 128, and tracking enabled.
//
// It is a package-level pointer, so every user that does not copy it shares
// the same value.
var DefaultPolicyConfig = &PolicyConfig{
	Loss:         modelv1.MSE,
	Optimizer:    g.NewAdamSolver(g.WithBatchSize(128), g.WithLearnRate(0.0005)),
	LayerBuilder: DefaultFCLayerBuilder,
	BatchSize:    128,
	Track:        true,
}
// LayerBuilder is a function that, given the model input x and output y,
// returns the layer configs that make up the network.
type LayerBuilder func(x, y *modelv1.Input) []layer.Config
// DefaultFCLayerBuilder builds a stack of three fully connected layers: two
// layers with 512 outputs each, followed by an output layer with linear
// activation. The first layer's input size is the first dimension of
// x.Squeeze() and the last layer's output size is the first dimension of
// y.Squeeze().
var DefaultFCLayerBuilder = func(x, y *modelv1.Input) []layer.Config {
	// Width shared by the two hidden layers.
	const hidden = 512

	inSize := x.Squeeze()[0]
	outSize := y.Squeeze()[0]

	first := layer.FC{Input: inSize, Output: hidden}
	second := layer.FC{Input: hidden, Output: hidden}
	last := layer.FC{Input: hidden, Output: outSize, Activation: layer.Linear}

	return []layer.Config{first, second, last}
}
// MakePolicy builds and compiles the sequential model used as a policy.
//
// The model input, named "stateGoal", is shaped by taking the squeezed
// observation space shape of env and passing it through dense.MulShape with a
// factor of 2. The model output, named "actionPotentials", is shaped by
// envv1.PotentialsShape for env's action space. EnsureBatch is applied to both
// before the model is compiled.
//
// A nil config falls back to DefaultPolicyConfig, and a nil
// config.LayerBuilder falls back to DefaultFCLayerBuilder. When config.Track
// is set the model is compiled with base.Tracker; otherwise it is compiled
// without a tracker.
//
// env must be non-nil, and base must be non-nil when tracking is enabled.
//
// It returns the compiled model, or the error reported while creating or
// compiling it.
func MakePolicy(name string, config *PolicyConfig, base *agentv1.Base, env *envv1.Env) (modelv1.Model, error) {
	// Guard against nil configuration instead of dereferencing a nil pointer
	// or calling a nil builder further down.
	if config == nil {
		config = DefaultPolicyConfig
	}
	buildLayers := config.LayerBuilder
	if buildLayers == nil {
		buildLayers = DefaultFCLayerBuilder
	}

	stateGoalShape := dense.MulShape(dense.SqueezeShape(env.ObservationSpaceShape()), 2)
	x := modelv1.NewInput("stateGoal", stateGoalShape)
	x.EnsureBatch()

	y := modelv1.NewInput("actionPotentials", envv1.PotentialsShape(env.ActionSpace))
	y.EnsureBatch()

	log.Debugv("x shape", x.Shape())
	log.Debugv("y shape", y.Shape())

	model, err := modelv1.NewSequential(name)
	if err != nil {
		return nil, err
	}
	model.AddLayers(buildLayers(x, y)...)

	opts := modelv1.NewOpts()
	opts.Add(
		modelv1.WithOptimizer(config.Optimizer),
		modelv1.WithLoss(config.Loss),
		modelv1.WithBatchSize(config.BatchSize),
	)
	if config.Track {
		opts.Add(modelv1.WithTracker(base.Tracker))
	} else {
		opts.Add(modelv1.WithoutTracker())
	}

	if err := model.Compile(x, y, opts.Values()...); err != nil {
		return nil, err
	}
	return model, nil
}