Skip to content

Commit

Permalink
Fix numerical issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Jan 20, 2025
1 parent 1677b97 commit e2da469
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 50 deletions.
38 changes: 38 additions & 0 deletions ext/stats/moments.nogo
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package stats

import "math"

// moment is the state updated by enqueue and dequeue: four kahan
// accumulators and the number of values currently held.
//
// NOTE(review): m1 is maintained as a running mean; m2 through m4 presumably
// track higher-order quantities derived from it — confirm their exact
// meaning before relying on them.
type moment struct {
	m1 kahan
	m2 kahan
	m3 kahan
	m4 kahan
	n  int64 // values enqueued and not yet dequeued
}

// enqueue folds the value x into the accumulators.
//
// The count is incremented first. m1 then moves towards x by 1/n of the
// difference between x and its current compensated value (hi + lo). Each
// following accumulator takes the previous step's term, multiplies it by x,
// subtracts its own current value, and adds 1/n of that result.
func (w *moment) enqueue(x float64) {
	w.n++
	count := float64(w.n)

	term := x - w.m1.hi - w.m1.lo
	w.m1.add(term / count)

	// The same update shape repeats for m2, m3 and m4, chaining term along.
	for _, acc := range [...]*kahan{&w.m2, &w.m3, &w.m4} {
		term = math.FMA(term, x, -acc.hi) - acc.lo
		acc.add(term / count)
	}
}

// dequeue removes the value x from the accumulators.
//
// If removing x would leave no values (the decremented count is zero or
// negative), the whole state is reset to its zero value instead of dividing
// by a non-positive count.
func (w *moment) dequeue(x float64) {
	n := w.n - 1
	if n <= 0 {
		*w = moment{}
		return
	}
	w.n = n

	// The compensated mean is hi + lo, so both parts are subtracted from x.
	// Mixing the signs would shift the mean whenever lo is non-zero, even
	// when removing a value equal to the mean.
	y := x - w.m1.hi - w.m1.lo
	w.m1.sub(y / float64(n))

	// NOTE(review): these steps add the accumulator's current value to the
	// chained term, whereas enqueue subtracts it. Confirm that this is the
	// intended inverse of enqueue before enabling this file.
	y = math.FMA(y, x, w.m2.hi) + w.m2.lo
	w.m2.sub(y / float64(n))
	y = math.FMA(y, x, w.m3.hi) + w.m3.lo
	w.m3.sub(y / float64(n))
	y = math.FMA(y, x, w.m4.hi) + w.m4.lo
	w.m4.sub(y / float64(n))
}
3 changes: 3 additions & 0 deletions ext/stats/percentile.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import (
"github.com/ncruces/sort/quick"
)

// Compatible with:
// https://sqlite.org/src/file/ext/misc/percentile.c

const (
median = iota
percentile_100
Expand Down
38 changes: 33 additions & 5 deletions ext/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
// - regr_json: all regr stats in a JSON object
// - regr_json: all regr stats as a JSON object
// - percentile_disc: discrete quantile
// - percentile_cont: continuous quantile
// - percentile: continuous percentile
Expand Down Expand Up @@ -111,6 +111,17 @@ type variance struct {
}

func (fn *variance) Value(ctx sqlite3.Context) {
switch fn.n {
case 1:
switch fn.kind {
case var_pop, stddev_pop:
ctx.ResultFloat(0)
}
return
case 0:
return
}

var r float64
switch fn.kind {
case var_pop:
Expand Down Expand Up @@ -151,6 +162,25 @@ type covariance struct {
}

func (fn *covariance) Value(ctx sqlite3.Context) {
if fn.kind == regr_count {
ctx.ResultInt64(fn.regr_count())
return
}
switch fn.n {
case 1:
switch fn.kind {
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
ctx.ResultFloat(0)
return
case regr_avgx, regr_avgy:
break
default:
return
}
case 0:
return
}

var r float64
switch fn.kind {
case var_pop:
Expand All @@ -175,11 +205,9 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_count:
ctx.ResultInt64(fn.regr_count())
return
case regr_json:
ctx.ResultText(fn.regr_json())
var buf [128]byte
ctx.ResultRawText(fn.regr_json(buf[:0]))
return
}
ctx.ResultFloat(r)
Expand Down
45 changes: 41 additions & 4 deletions ext/stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,23 @@ func TestRegister_variance(t *testing.T) {
t.Fatal(err)
}

stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()

err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
if err != nil {
t.Fatal(err)
}

stmt, _, err := db.Prepare(`
stmt, _, err = db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
Expand Down Expand Up @@ -65,7 +76,11 @@ func TestRegister_variance(t *testing.T) {
}
stmt.Close()

stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err = db.Prepare(`
SELECT
var_samp(x) OVER (ROWS 1 PRECEDING),
var_pop(x) OVER (ROWS 1 PRECEDING)
FROM data`)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -96,12 +111,26 @@ func TestRegister_covariance(t *testing.T) {
t.Fatal(err)
}

stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 0", got)
}
if got := stmt.ColumnType(1); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()

err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}

stmt, _, err := db.Prepare(`SELECT
stmt, _, err = db.Prepare(`SELECT
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
Expand Down Expand Up @@ -157,7 +186,12 @@ func TestRegister_covariance(t *testing.T) {
}
stmt.Close()

stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err = db.Prepare(`
SELECT
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
FROM data`)
if err != nil {
t.Fatal(err)
}
Expand All @@ -171,6 +205,9 @@ func TestRegister_covariance(t *testing.T) {
t.Errorf("got %v, want %v", got, want[i])
}
}
if stmt.Err() != nil {
t.Fatal(stmt.Err())
}
stmt.Close()
}

Expand Down
83 changes: 45 additions & 38 deletions ext/stats/welford.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package stats
import (
"math"
"strconv"
"strings"

"github.com/ncruces/go-sqlite3/internal/util"
)

// Welford's algorithm with Kahan summation:
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm

// See also:
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates

type welford struct {
m1, m2 kahan
n int64
Expand All @@ -39,17 +38,23 @@ func (w welford) stddev_samp() float64 {
}

func (w *welford) enqueue(x float64) {
w.n++
n := w.n + 1
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.add(d1 / float64(w.n))
w.m1.add(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.add(d1 * d2)
}

func (w *welford) dequeue(x float64) {
w.n--
n := w.n - 1
if n <= 0 {
*w = welford{}
return
}
w.n = n
d1 := x - w.m1.hi - w.m1.lo
w.m1.sub(d1 / float64(w.n))
w.m1.sub(d1 / float64(n))
d2 := x - w.m1.hi - w.m1.lo
w.m2.sub(d1 * d2)
}
Expand Down Expand Up @@ -112,38 +117,35 @@ func (w welford2) regr_r2() float64 {
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
}

func (w welford2) regr_json() string {
var json strings.Builder
var num [32]byte
json.Grow(128)
json.WriteString(`{"count":`)
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
json.WriteString(`,"avgy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
json.WriteString(`,"avgx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
json.WriteString(`,"syy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
json.WriteString(`,"sxx":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
json.WriteString(`,"sxy":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
json.WriteString(`,"slope":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
json.WriteString(`,"intercept":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
json.WriteString(`,"r2":`)
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
json.WriteByte('}')
return json.String()
func (w welford2) regr_json(dst []byte) []byte {
dst = append(dst, `{"count":`...)
dst = strconv.AppendInt(dst, w.regr_count(), 10)
dst = append(dst, `,"avgy":`...)
dst = util.AppendNumber(dst, w.regr_avgy())
dst = append(dst, `,"avgx":`...)
dst = util.AppendNumber(dst, w.regr_avgx())
dst = append(dst, `,"syy":`...)
dst = util.AppendNumber(dst, w.regr_syy())
dst = append(dst, `,"sxx":`...)
dst = util.AppendNumber(dst, w.regr_sxx())
dst = append(dst, `,"sxy":`...)
dst = util.AppendNumber(dst, w.regr_sxy())
dst = append(dst, `,"slope":`...)
dst = util.AppendNumber(dst, w.regr_slope())
dst = append(dst, `,"intercept":`...)
dst = util.AppendNumber(dst, w.regr_intercept())
dst = append(dst, `,"r2":`...)
dst = util.AppendNumber(dst, w.regr_r2())
return append(dst, '}')
}

func (w *welford2) enqueue(y, x float64) {
w.n++
n := w.n + 1
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.add(d1y / float64(w.n))
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(n))
w.m1x.add(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.add(d1y * d2y)
Expand All @@ -152,11 +154,16 @@ func (w *welford2) enqueue(y, x float64) {
}

func (w *welford2) dequeue(y, x float64) {
w.n--
n := w.n - 1
if n <= 0 {
*w = welford2{}
return
}
w.n = n
d1y := y - w.m1y.hi - w.m1y.lo
d1x := x - w.m1x.hi - w.m1x.lo
w.m1y.sub(d1y / float64(w.n))
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(n))
w.m1x.sub(d1x / float64(n))
d2y := y - w.m1y.hi - w.m1y.lo
d2x := x - w.m1x.hi - w.m1x.lo
w.m2y.sub(d1y * d2y)
Expand Down
22 changes: 22 additions & 0 deletions ext/stats/welford_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ func Test_welford(t *testing.T) {
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}

s1.dequeue(16)
s1.dequeue(7)
s1.dequeue(13)
s1.enqueue(16)
s1.enqueue(7)
s1.enqueue(13)
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}

func Test_covar(t *testing.T) {
Expand Down Expand Up @@ -65,6 +75,18 @@ func Test_covar(t *testing.T) {
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}

c1.dequeue(2, 60)
c1.dequeue(5, 80)
c1.dequeue(4, 75)
c1.dequeue(7, 90)
c1.enqueue(2, 60)
c1.enqueue(5, 80)
c1.enqueue(4, 75)
c1.enqueue(7, 90)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}

func Test_correlation(t *testing.T) {
Expand Down
Loading

0 comments on commit e2da469

Please sign in to comment.