-
Notifications
You must be signed in to change notification settings - Fork 12
/
num.go
284 lines (239 loc) · 6.37 KB
/
num.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
package tds
import (
"database/sql/driver"
"fmt"
"math"
"math/big"
"reflect"
"strconv"
"strings"
"sync"
"github.com/thda/tds/binary"
"errors"
)
// implement the numeric/nullnumeric data types, and provide Scanner/Valuer for it
// as well as serialization/deserialization methods.
// Num represents a Sybase numeric/decimal value: an exact rational number,
// the precision and scale of the column it is bound to, and a NULL marker.
type Num struct {
	r                big.Rat // exact value
	precision, scale int8    // column precision and number of fractional digits
	isNull           bool    // set when the value is SQL NULL
}
// decimalPowers[k] holds 10^k for every k in [0, 40], i.e. every exponent a
// numeric scale may need. It is mainly used to detect overflow / precision loss.
var decimalPowers [41]*big.Int

func init() {
	ten := big.NewInt(10)
	acc := big.NewInt(1)
	for k := range decimalPowers {
		// store an independent copy: acc keeps being multiplied in place
		decimalPowers[k] = new(big.Int).Set(acc)
		acc.Mul(acc, ten)
	}
}
//
// Pools for performance
//
// rPool recycles *big.Rat scratch values; New hands out a zero-valued Rat.
var rPool = sync.Pool{
	New: func() interface{} {
		return &big.Rat{}
	},
}
// numPool recycles *Num scratch values. Values are put back without being
// reset, so callers set every field they rely on after Get.
var numPool = sync.Pool{
	New: func() interface{} {
		return &Num{}
	},
}
//
// Scanner and Valuer to satisfy database/sql interfaces
//
// Scan implements the sql.Scanner interface.
// It initialises a tds.Num from a string, a []byte holding the same text,
// any Go integer or float kind, or a pointer to one of those.
// Text must be in decimal form with an optional sign and a dot as the
// separator, e.g. "-50.40" (big.Rat.SetString also accepts the exponent
// forms produced when formatting floats).
//
// A nil source or a nil pointer marks the Num as NULL; a successfully parsed
// value clears the NULL flag.
//
// Example:
//
//	num := Num{precision: p, scale: s}
//	num.Scan("-10.4")
//
// A loss of precision always causes an error: the value must be exactly
// representable with n.scale fractional digits, otherwise ErrOverFlow is
// returned. Only the scale is checked here, not the precision.
func (n *Num) Scan(src interface{}) error {
	// every supported input is funnelled through its decimal text form
	var strVal string
	switch v := src.(type) {
	case nil:
		n.isNull = true
		return nil
	case string:
		strVal = v
	case []byte:
		strVal = string(v)
	default:
		rv := reflect.ValueOf(src)
		switch rv.Kind() {
		case reflect.Ptr:
			if rv.IsNil() {
				n.isNull = true
				return nil
			}
			return n.Scan(rv.Elem().Interface())
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			strVal = strconv.FormatInt(rv.Int(), 10)
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			strVal = strconv.FormatUint(rv.Uint(), 10)
		case reflect.Float64:
			strVal = strconv.FormatFloat(rv.Float(), 'g', -1, 64)
		case reflect.Float32:
			strVal = strconv.FormatFloat(rv.Float(), 'g', -1, 32)
		default:
			return errors.New("unexpected type for numeric scan")
		}
	}
	// decimalPowers only covers scales 0..40; anything else would panic below
	if n.scale < 0 || int(n.scale) >= len(decimalPowers) {
		return fmt.Errorf("tds: numeric scale %d out of range", n.scale)
	}
	if _, ok := n.r.SetString(strVal); !ok {
		return fmt.Errorf("tds: could not parse string %s to number", strVal)
	}
	// a real value was scanned: the Num may previously have been NULL
	n.isNull = false
	// check for loss of precision: value * 10^scale must be an integer
	mul := rPool.Get().(*big.Rat).SetInt(decimalPowers[n.scale])
	defer rPool.Put(mul)
	if !mul.Mul(mul, &n.r).IsInt() {
		return ErrOverFlow
	}
	return nil
}
// String implements fmt.Stringer.
// Integral values are rendered without a fractional part ("42", "-7").
// Other values are rendered in plain decimal notation with at most n.scale
// fractional digits, trailing zeros removed and always with an integer part
// ("0.5", "-0.05", "10.25"). A value that cannot be represented exactly with
// n.scale fractional digits (or an out-of-range scale) yields
// "incorrect rational".
func (n Num) String() string {
	// shortcut for integers: RatString omits the "/1" denominator
	if n.r.IsInt() {
		return n.r.RatString()
	}
	// decimalPowers only covers scales 0..40
	if n.scale < 0 || int(n.scale) >= len(decimalPowers) {
		return "incorrect rational"
	}
	// the value must be exact with n.scale fractional digits
	mul := rPool.Get().(*big.Rat).SetInt(decimalPowers[n.scale])
	defer rPool.Put(mul)
	if !mul.Mul(&n.r, mul).IsInt() {
		return "incorrect rational"
	}
	// FloatString is exact here (no rounding happens after the check above).
	// The value is not an integer, so the fraction holds a non-zero digit and
	// trimming trailing zeros can never reach the decimal point.
	return strings.TrimRight(n.r.FloatString(int(n.scale)), "0")
}
// Rat returns a copy of the underlying big.Rat value.
// The copy owns its own storage: a plain struct copy of a big.Rat would share
// its digit slices with n, so in-place arithmetic on the result (e.g. Add)
// could silently modify the Num it came from.
func (n Num) Rat() big.Rat {
	var out big.Rat
	out.Set(&n.r)
	return out
}
// numConverter converts parameter values destined for a numeric column with
// the given precision and scale, rejecting values that would overflow it.
type numConverter struct {
	precision, scale int8 // target column precision and scale
}
// ConvertValue implements driver.ValueConverter for numeric parameters.
//
// nil, NULL Nums, nil pointers and Valuers returning nil all convert to nil.
// A Num (or *Num) is accepted as-is when its declared precision and scale do
// not exceed the converter's, otherwise ErrOverFlow is returned. Any other
// value is parsed with Num.Scan using the converter's precision and scale.
// The result is the value's decimal text (Num.String) as a []byte.
func (nc numConverter) ConvertValue(src interface{}) (driver.Value, error) {
	if src == nil {
		return nil, nil
	}
	// treat *Num like Num; Scan cannot parse a Num struct
	if p, ok := src.(*Num); ok {
		if p == nil {
			return nil, nil
		}
		src = *p
	}
	if v, ok := src.(Num); ok {
		if v.isNull {
			return nil, nil
		}
		// check for loss of precision
		if v.precision > nc.precision || v.scale > nc.scale {
			return nil, ErrOverFlow
		}
		return []byte(v.String()), nil
	}
	// get a scratch num from the pool; every field Scan relies on is reset
	num := numPool.Get().(*Num)
	defer numPool.Put(num)
	num.precision, num.scale, num.isNull = nc.precision, nc.scale, false
	// unwrap driver values first
	if valuer, ok := src.(driver.Valuer); ok {
		v, err := valuer.Value()
		if err != nil {
			return nil, err
		}
		if v == nil {
			return nil, nil
		}
		src = v
	}
	// use scan to convert to numeric
	if err := num.Scan(src); err != nil {
		return nil, err
	}
	// Scan flags typed nil pointers as NULL without touching num.r, which
	// still holds whatever the pooled value last contained
	if num.isNull {
		return nil, nil
	}
	return []byte(num.String()), nil
}
//
// Encoding routines.
//
// encodeNumeric writes a numeric, decimal, money or smallmoney value to the
// wire. Money/smallmoney fields are handled here as they are indeed numeric
// in disguise. s must be the []byte decimal text given by numConverter.
//
// The text is parsed with the column's precision/scale, multiplied by
// 10^scale and written as:
//   - smallmoney: a 4-byte signed integer,
//   - money: the high 32 bits, then the low 32 bits of a 64-bit integer,
//   - numeric/decimal: a length byte (sign byte + magnitude bytes), a sign
//     byte (0x00 positive, 0x01 negative), then the big-endian magnitude.
//
// Values that do not fit the target representation yield ErrOverFlow, and an
// unexpected column data type is an error rather than a silent no-op.
// NOTE(review): the column precision itself is not validated here (Num.Scan
// only checks the scale); the server is presumably expected to reject it.
func encodeNumeric(e *binary.Encoder, s interface{}, i colType) (err error) {
	raw, ok := s.([]byte)
	if !ok {
		return errors.New("invalid data type for numeric")
	}
	// decimalPowers only covers scales 0..40
	if i.scale < 0 || int(i.scale) >= len(decimalPowers) {
		return ErrOverFlow
	}
	num := numPool.Get().(*Num)
	defer numPool.Put(num)
	num.precision, num.scale, num.isNull = i.precision, i.scale, false
	if err = num.Scan(string(raw)); err != nil {
		return fmt.Errorf("tds: error while scanning array of bytes to numeric: %w", err)
	}
	// Multiply by the scale before serializing
	mul := rPool.Get().(*big.Rat).SetInt(decimalPowers[i.scale])
	defer rPool.Put(mul)
	num.r.Mul(&num.r, mul)
	// no loss of precision will be tolerated.
	if !num.r.IsInt() {
		return ErrOverFlow
	}
	unscaled := num.r.Num()
	// write to the wire as money or numeric, depending on data type
	switch i.dataType {
	case smallmoneyType:
		if !unscaled.IsInt64() || unscaled.Int64() < math.MinInt32 || unscaled.Int64() > math.MaxInt32 {
			return ErrOverFlow
		}
		e.WriteInt32(int32(unscaled.Int64()))
	case moneyNType, moneyType:
		// Int64 is undefined when the value does not fit
		if !unscaled.IsInt64() {
			return ErrOverFlow
		}
		v := unscaled.Int64()
		e.WriteUint32(uint32(v >> 32))
		e.WriteInt32(int32(v))
	case decimalType, numericType, decimalNType, numericNType:
		mag := unscaled.Bytes()
		// the length byte counts the sign byte plus the magnitude bytes
		if len(mag)+1 > math.MaxInt8 {
			return ErrOverFlow
		}
		e.WriteInt8(int8(len(mag) + 1))
		if unscaled.Sign() >= 0 {
			e.WriteByte(0x00)
		} else {
			e.WriteByte(0x01)
		}
		e.Write(mag)
	default:
		return fmt.Errorf("tds: unsupported data type %v for numeric encoding", i.dataType)
	}
	return e.Err()
}
// decodeNumeric decodes a numeric/decimal value from the wire.
// It reads one sign byte (non-zero means negative), then i.bufferSize-1
// big-endian magnitude bytes, and divides the magnitude by 10^i.scale.
// The result is returned as a Num carrying the column's precision and scale.
// A read error or an unsupported scale yields a nil value and an error.
func decodeNumeric(e *binary.Encoder, i colType) (interface{}, error) {
	// the buffer must at least hold the sign byte; a smaller size would
	// also make the make() below panic
	if i.bufferSize < 1 {
		return nil, fmt.Errorf("tds: invalid numeric buffer size %d", i.bufferSize)
	}
	sign := e.Int8()
	// read all the magnitude bytes
	mag := make([]byte, i.bufferSize-1)
	e.Read(mag)
	if err := e.Err(); err != nil {
		return nil, err
	}
	// safety check: decimalPowers only covers scales 0..40
	if i.scale < 0 || int(i.scale) >= len(decimalPowers) {
		return nil, ErrOverFlow
	}
	// cast as a big.Rat
	out := new(big.Rat).SetFrac(new(big.Int).SetBytes(mag), decimalPowers[i.scale])
	if sign != 0 {
		out.Neg(out)
	}
	return Num{r: *out, precision: i.precision, scale: i.scale}, nil
}