Skip to content

Commit

Permalink
Add method Regression.GetCoeffs (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
updogliu authored and mish15 committed Nov 19, 2019
1 parent 24a553f commit d629f2e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
12 changes: 12 additions & 0 deletions regression.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,18 @@ func (r *Regression) Coeff(i int) float64 {
return r.coeff[i]
}

// GetCoeffs returns the calculated coefficients. The element at index 0 is the offset.
func (r *Regression) GetCoeffs() []float64 {
if len(r.coeff) == 0 {
return nil
}
coeffs := make([]float64, len(r.coeff))
for i := range coeffs {
coeffs[i] = r.coeff[i]
}
return coeffs
}

func (r *Regression) calcPredicted() string {
observations := len(r.data)
var predicted float64
Expand Down
30 changes: 30 additions & 0 deletions regression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package regression

import (
"fmt"
"math"
"testing"
)

Expand Down Expand Up @@ -154,5 +155,34 @@ func TestMakeDataPoints(t *testing.T) {
t.Error("Expected observed to be the same as the index")
}
}
}

func TestGetCoeffs(t *testing.T) {
a := [][]float64{
{651, 1, 23},
{762, 2, 26},
{856, 3, 30},
{1063, 4, 34},
{1190, 5, 43},
{1298, 6, 48},
{1421, 7, 52},
{1440, 8, 57},
{1518, 9, 58},
}

r := new(Regression)
r.Train(MakeDataPoints(a, 0)...)
r.Run()

coeffs := r.GetCoeffs()
if len(coeffs) != 3 {
t.Errorf("Expected 3 coefficients. Got %v instead", len(coeffs))
}

expected := []float64{323.54, 46.60, 13.99}
for i := range expected {
if math.Abs(expected[i]-coeffs[i]) > 0.01 {
t.Errorf("Expected coefficient %v to be %v. Got %v instead", i, expected[i], coeffs[i])
}
}
}

0 comments on commit d629f2e

Please sign in to comment.