Skip to content

Commit

Permalink
Merge pull request #4 from ChristopherRabotin/issue-3
Browse files Browse the repository at this point in the history
Fix RK4 implementation errors
  • Loading branch information
ChristopherRabotin authored Jan 18, 2017
2 parents e5144d7 + fc9651b commit 0502445
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 56 deletions.
4 changes: 2 additions & 2 deletions examples/angularMomentum/attitude.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const (
)

var (
eye = mat64.NewDense(3, 3, []float64{1, 0, 0, 0, 0, 0, 0, 0, 1})
eye = mat64.NewDense(3, 3, []float64{1, 0, 0, 0, 1, 0, 0, 0, 1})
)

/*-----*/
Expand Down Expand Up @@ -132,7 +132,7 @@ func (a *Attitude) GetState() []float64 {
}

// SetState sets the state of this attitude for the EOM as defined below.
func (a *Attitude) SetState(i uint64, s []float64) {
func (a *Attitude) SetState(t float64, s []float64) {
a.Attitude.s1 = s[0]
a.Attitude.s2 = s[1]
a.Attitude.s3 = s[2]
Expand Down
4 changes: 2 additions & 2 deletions integrable.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package ode
// WARNING: Implementation must manage its own state based on the iteration.
type Integrable interface {
GetState() []float64 // Get the latest state of this integrable.
SetState(i uint64, s []float64) // Set the state s of a given iteration i.
Stop(i uint64) bool // Return whether to stop the integration from iteration i.
SetState(t float64, s []float64) // Set the state s of a given time t.
Stop(t float64) bool // Return whether to stop the integration at time t.
Func(t float64, s []float64) []float64 // ODE function from time t and state s, must return a new state.
}
53 changes: 24 additions & 29 deletions rk4.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,45 +22,40 @@ func NewRK4(x0 float64, stepSize float64, inte Integrable) (r *RK4) {
// Solve solves the configured RK4.
// Returns the number of iterations performed and the last X_i, or an error.
func (r *RK4) Solve() (uint64, float64, error) {
const (
half = 1 / 2.0
oneSixth = 1 / 6.0
oneThird = 1 / 3.0
)

iterNum := uint64(0)
xi := r.X0
for !r.Integator.Stop(iterNum) {
halfStep := xi * half
for !r.Integator.Stop(xi) {
halfStep := r.StepSize * 0.5
state := r.Integator.GetState()
newState := make([]float64, len(state))
//k1, k2, k3, k4 are used as buffers AND result variables.
k1 := make([]float64, len(state))
k2 := make([]float64, len(state))
k3 := make([]float64, len(state))
k4 := make([]float64, len(state))
tState := make([]float64, len(state))
z := make([]float64, len(state)) // a temporary variable

// Compute the k's.
for i, y := range r.Integator.Func(xi, state) {
k1[i] = y * r.StepSize
tState[i] = state[i] + k1[i]*half
}
for i, y := range r.Integator.Func(xi+halfStep, tState) {
k2[i] = y * r.StepSize
tState[i] = state[i] + k2[i]*half
// Step 1
f1 := r.Integator.Func(xi, state)

// Step 2
for i := 0; i < len(state); i++ {
z[i] = state[i] + halfStep*f1[i]
}
for i, y := range r.Integator.Func(xi+halfStep, tState) {
k3[i] = y * r.StepSize
tState[i] = state[i] + k3[i]
f2 := r.Integator.Func(xi+halfStep, z)

// Step 3
for i := 0; i < len(state); i++ {
z[i] = state[i] + halfStep*f2[i]
}
for i, y := range r.Integator.Func(xi+halfStep, tState) {
k4[i] = y * r.StepSize
newState[i] = state[i] + oneSixth*(k1[i]+k4[i]) + oneThird*(k2[i]+k3[i])
f3 := r.Integator.Func(xi+halfStep, z)

// Step 4
for i := 0; i < len(state); i++ {
z[i] = state[i] + r.StepSize*f3[i]
}
r.Integator.SetState(iterNum, newState)
f4 := r.Integator.Func(xi+r.StepSize, z)

for i := 0; i < len(state); i++ {
newState[i] = state[i] + r.StepSize*(f1[i]+2*f2[i]+2*f3[i]+f4[i])/6
}
xi += r.StepSize
r.Integator.SetState(xi, newState)
iterNum++ // Don't forget to increment the number of iterations.
}

Expand Down
176 changes: 153 additions & 23 deletions rk4_test.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,47 @@
package ode

import (
"fmt"
"math"
"testing"

"github.com/ChristopherRabotin/ode/examples/angularMomentum"
"github.com/gonum/floats"
)

const (
tolerance = 1e-10
)

func TestPanics(t *testing.T) {
assertPanic(t, "negative step", func() {
NewRK4(1, -1, nil)
})

assertPanic(t, "nil integrator", func() {
NewRK4(1, 1, nil)
})
}

type Balbasi1D struct {
state []float64 // Note that we don't have a state history here.
prevIt uint
state []float64
}

func NewBalbasi1D() (b *Balbasi1D) {
b = &Balbasi1D{}
b.state = []float64{1200.0}
b.prevIt = 0
return
}

func (b *Balbasi1D) GetState() []float64 {
return b.state
}

func (b *Balbasi1D) SetState(i uint64, s []float64) {
func (b *Balbasi1D) SetState(t float64, s []float64) {
b.state = s
if i != 0 && b.prevIt+1 != uint(i) {
panic(fmt.Errorf("expected i=%d, got i=%d", b.prevIt+1, i))
}
b.prevIt = uint(i)
}

func (b *Balbasi1D) Stop(i uint64) bool {
return i*30 >= 480
func (b *Balbasi1D) Stop(t float64) bool {
return t > 480
}

func (b *Balbasi1D) Func(t float64, s []float64) []float64 {
Expand All @@ -56,8 +64,9 @@ type AttitudeTest struct {
*dynamics.Attitude
}

func (a *AttitudeTest) Stop(i uint64) bool {
return float64(i)*1e-6 >= 1e-1
func (a *AttitudeTest) Stop(t float64) bool {
// Propagate for 0.1 seconds.
return t > 1e-1
}

func NewAttitudeTest() (a *AttitudeTest) {
Expand All @@ -69,20 +78,141 @@ func NewAttitudeTest() (a *AttitudeTest) {
func TestRK4Attitude(t *testing.T) {
inte := NewAttitudeTest()
initMom := inte.Momentum()
if _, _, err := NewRK4(0, 1e-6, inte).Solve(); err != nil {
for _, step := range []float64{1e-6, 1e-8} {
if _, _, err := NewRK4(0, step, inte).Solve(); err != nil {
t.Fatalf("err: %+v\n", err)
}
if diff := math.Abs(initMom - inte.Momentum()); diff > 1e-8 {
t.Fatalf("angular momentum changed by %4.12f", diff)
}
}
}

type V1DSimple struct {
state []float64
}

func (v *V1DSimple) GetState() []float64 {
return v.state
}

func (v *V1DSimple) SetState(t float64, s []float64) {
v.state = s
}

func (v *V1DSimple) Stop(t float64) bool {
return floats.EqualWithinAbs(t, 1, tolerance)
}

func (v *V1DSimple) Func(x float64, y []float64) []float64 {
return []float64{-2 * y[0]}
}

func TestRK4Simple1D(t *testing.T) {
inte := new(V1DSimple)
inte.state = []float64{1}
iterNum, xi, err := NewRK4(0, 0.1, inte).Solve()
if err != nil {
t.Fatalf("err: %+v\n", err)
}
if diff := math.Abs(initMom - inte.Momentum()); diff > 1e-8 {
t.Fatalf("angular momentum changed by %4.12f", diff)
if !floats.EqualWithinAbs(xi, 1, tolerance) {
t.Fatalf("xi=%f != 1.0", xi)
}
if iterNum != 10 {
t.Fatalf("iterNum=%d != 10", iterNum)
}
exp := 0.1353395484305101
if !floats.EqualWithinAbs(inte.GetState()[0], exp, tolerance) {
t.Fatalf("\nstate=%f\n exp=%f", inte.GetState()[0], exp)
}
}

func TestPanics(t *testing.T) {
assertPanic(t, "negative step", func() {
NewRK4(1, -1, nil)
})
type VSimple struct {
state []float64
}

assertPanic(t, "nil integrator", func() {
NewRK4(1, 1, nil)
})
func (v *VSimple) GetState() []float64 {
return v.state
}

func (v *VSimple) SetState(t float64, s []float64) {
v.state = s
}

func (v *VSimple) Stop(t float64) bool {
return t >= 37.8
}

func (v *VSimple) Func(x float64, y []float64) []float64 {
return []float64{y[1], -y[0]}
}

func TestRK4Simple(t *testing.T) {
inte := new(VSimple)
inte.state = []float64{0, 1}
iterNum, xi, err := NewRK4(0, 0.2, inte).Solve()
if err != nil {
t.Fatalf("err: %+v\n", err)
}
if xi != 37.8 {
t.Fatalf("xi=%f != 37.8", xi)
}
if iterNum != 189 {
t.Fatalf("iterNum=%d != 189", iterNum)
}
exp := []float64{+1.0021441571397413e-01, +9.9488186473553231e-01}
state := inte.GetState()
if exp[0] != state[0] || exp[1] != state[1] {
t.Fatalf("\nstate=%+v\n exp=%+v", state, exp)
}
}

type KraichnanOrszag struct {
steps uint64
state []float64
}

func (v *KraichnanOrszag) GetState() []float64 {
return v.state
}

func (v *KraichnanOrszag) SetState(t float64, s []float64) {
v.state = s
}

func (v *KraichnanOrszag) Stop(t float64) bool {
return t >= float64(v.steps)*0.01
}

func (v *KraichnanOrszag) Func(x float64, y []float64) []float64 {
return []float64{y[0] * y[2], -y[1] * y[2], -y[0]*y[0] + y[1]*y[1]}
}

func TestRK4KO(t *testing.T) {
for _, steps := range []uint64{30, 3000} {
inte := &KraichnanOrszag{steps, []float64{1, .4, .2}}
iterNum, xi, err := NewRK4(0, 0.01, inte).Solve()
if err != nil {
t.Fatalf("err: %+v\n", err)
}
if iterNum != steps {
t.Fatalf("iterNum=%d != %d", iterNum, steps)
}
var exp []float64
if steps == 30 {
if !floats.EqualWithinAbs(xi, 0.3, tolerance) {
t.Fatalf("xi=%.16f != 0.3", xi)
}
exp = []float64{1.0209861554390987e+00, 3.9177808423412647e-01, -6.4009398764259762e-02}
} else {
if !floats.EqualWithinAbs(xi, 30, tolerance) {
t.Fatalf("xi=%.16f != 30", xi)
}
exp = []float64{4.8745696934931565e-01, 8.2058525186654796e-01, -5.3761096276439480e-01}
}
state := inte.GetState()
if !floats.EqualApprox(exp, state, tolerance) {
t.Fatalf("\nstate=%+v\n exp=%+v", state, exp)
}
}
}

0 comments on commit 0502445

Please sign in to comment.