-
Notifications
You must be signed in to change notification settings - Fork 12
/
stmt.go
238 lines (191 loc) · 6.36 KB
/
stmt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
package tds
import (
"context"
"database/sql/driver"
"fmt"
"sync"
"sync/atomic"
)
// TODO: batch inserts via Batch/Submit methods.

// Stmt is a prepared statement implementing the driver.Stmt interface.
// It is built by newStmt, which prepares the query on the server as a
// dynamic procedure named "gtds<ID>".
type Stmt struct {
	d          *dynamic                // dynamic message, switched between prepare, exec and dealloc operations
	s          *session                // session the statement was prepared on
	ID         int64                   // statement number, embedded in the procedure name
	row        *row                    // parameter values
	paramFmts  *columns                // parameter formats sent ahead of the values; nil when there are no parameters
	paramToken token                   // NOTE(review): not read or written anywhere in this file
	msgs       []messageReaderWriter   // messages sent on every execution, cached at prepare time
	converters []driver.ValueConverter // one converter per parameter, served by ColumnConverter
	ctx        context.Context         // context given at prepare time, reused by Exec, Query and Close
	values     []driver.Value          // scratch buffer reused by ExecContext for argument values
}
// stmtID is the last statement number handed out. It is only modified via
// atomic operations.
var stmtID int64

// idPool recycles statement numbers released by Stmt.Close so they can be
// reused. When the pool is empty, New hands out a fresh number from stmtID.
var idPool = sync.Pool{
	New: func() interface{} {
		// Return a pointer to a private copy. Returning &stmtID would give
		// every caller the shared counter, and its later non-atomic read could
		// yield the same number to two concurrently prepared statements.
		id := atomic.AddInt64(&stmtID, 1)
		return &id
	},
}
// newStmt prepares query on session s as a dynamic statement. It sends a
// "create proc gtds<ID> as <query>" prepare request, reads back the
// parameter formats announced by the server and returns a Stmt that is ready
// to be executed.
//
// On failure the returned *Stmt is non-nil but unusable; callers should only
// inspect the error.
func newStmt(ctx context.Context, s *session, query string) (*Stmt, error) {
	st := &Stmt{s: s, row: &row{}, ctx: ctx}
	if !s.valid {
		return st, driver.ErrBadConn
	}

	paramFmts := &columns{msg: newMsg(paramFmtToken), flags: param}
	wideFmts := &columns{msg: newMsg(paramFmt2Token), flags: wide | param}
	st.row = &row{msg: newMsg(paramToken)}

	// Reserve a statement number; it names the server-side procedure.
	st.ID = *idPool.Get().(*int64)
	procName := fmt.Sprintf("gtds%d", st.ID)
	st.d = &dynamic{
		msg:       newMsg(dynamic2Token),
		operation: dynamicPrepare,
		name:      procName,
		statement: "create proc " + procName + " as " + query,
	}

	if err := s.checkErr(s.b.send(ctx, normalPacket, st.d), "tds: Prepare failed", false); err != nil {
		return st, err
	}
	st.s.clearResult()
	// The statement text is only needed for the prepare request.
	st.d.statement = ""

	// Read the parameter formats and the prepare status. The server also
	// sends a row format, which is safe to ignore because it is resent on
	// execution.
	handlers := map[token]messageReader{
		dynamic2Token:  st.d,
		paramFmtToken:  paramFmts,
		paramFmt2Token: wideFmts,
	}
	for step := s.initState(ctx, handlers); step != nil; step = step(s.state) {
	}
	if err := s.checkErr(s.state.err, "tds: Prepare failed", true); err != nil {
		return st, err
	}

	// Wide parameter formats take precedence over regular ones.
	switch {
	case len(wideFmts.fmts) > 0:
		st.row.columns = wideFmts.fmts
		st.paramFmts = wideFmts
	case len(paramFmts.fmts) > 0:
		st.row.columns = paramFmts.fmts
		st.paramFmts = paramFmts
	}

	st.values = make([]driver.Value, len(st.row.columns))
	// From now on the dynamic message executes the procedure.
	st.d.operation = dynamicExec

	if st.paramFmts == nil {
		// No parameters: only the dynamic token is sent on each exec.
		st.d.status &^= dynamicHasArgs
		st.msgs = []messageReaderWriter{st.d}
		return st, nil
	}

	st.d.status |= dynamicHasArgs
	st.converters = make([]driver.ValueConverter, len(st.paramFmts.fmts))
	for i := range st.paramFmts.fmts {
		st.converters[i] = st.paramFmts.fmts[i].parameterConverter()
	}
	// Each exec sends the dynamic token, then the parameter formats, then the values.
	st.msgs = []messageReaderWriter{st.d, st.paramFmts, st.row}
	return st, nil
}
// send transmits the cached execute messages to the server, using args as the
// parameter values. It returns driver.ErrBadConn when the session is no
// longer valid, and an error when len(args) differs from the number of
// prepared parameters. The session result is cleared after every send
// attempt.
func (st *Stmt) send(ctx context.Context, args []driver.Value) (err error) {
	if !st.s.valid {
		return driver.ErrBadConn
	}
	if want, got := len(st.row.columns), len(args); want != got {
		return fmt.Errorf("tds: parameter count mismatch, expected %d, got %d", want, got)
	}

	st.row.data = args
	err = st.s.b.send(ctx, normalPacket, st.msgs...)
	st.s.clearResult()
	return err
}
// Exec executes the prepared statement with args and discards any rows the
// server sends back. It implements driver.Stmt and runs under the context
// captured when the statement was prepared.
func (st *Stmt) Exec(args []driver.Value) (res driver.Result, err error) {
	return st.execDiscard(st.ctx, args, "tds: Exec failed")
}

// ExecContext executes the prepared statement under ctx and discards any
// rows the server sends back. It implements driver.StmtExecContext.
// The number of arguments must match the number of prepared parameters.
func (st *Stmt) ExecContext(ctx context.Context, namedArgs []driver.NamedValue) (res driver.Result, err error) {
	if len(namedArgs) != len(st.values) {
		return &emptyResult, fmt.Errorf("tds: ExecContext, invalid arg count")
	}
	// Reuse the value buffer allocated by newStmt instead of allocating per call.
	for i := range namedArgs {
		st.values[i] = namedArgs[i].Value
	}
	return st.execDiscard(ctx, st.values, "tds: ExecContext failed")
}

// execDiscard sends args, reads the server response while discarding any
// rows, and returns a copy of the session's result. failMsg labels errors
// raised while reading the response.
func (st *Stmt) execDiscard(ctx context.Context, args []driver.Value, failMsg string) (driver.Result, error) {
	if err := st.send(ctx, args); err != nil {
		return &emptyResult, st.s.checkErr(err, "tds: send failed while execing", false)
	}
	rows, err := newRow(ctx, st.s)
	if err == nil {
		// Closing the rows discards whatever the statement returned.
		err = rows.Close()
	}
	err = st.s.checkErr(err, failMsg, true)
	// Return a copy rather than st.s.res itself. The session's result is
	// shared state, and the next statement on this connection (once
	// database/sql releases it back to the pool) may overwrite it before the
	// caller reads RowsAffected.
	res := *st.s.res
	return &res, err
}
// Query executes the prepared statement with args and returns the resulting
// rows. It implements driver.Stmt and runs under the context captured when
// the statement was prepared.
func (st *Stmt) Query(args []driver.Value) (driver.Rows, error) {
	return st.queryRows(st.ctx, args, "tds: Query failed")
}

// QueryContext executes the prepared statement under ctx and returns the
// resulting rows. It implements driver.StmtQueryContext.
func (st *Stmt) QueryContext(ctx context.Context, namedArgs []driver.NamedValue) (driver.Rows, error) {
	args := make([]driver.Value, len(namedArgs))
	for i, nv := range namedArgs {
		args[i] = nv.Value
	}
	return st.queryRows(ctx, args, "tds: QueryContext failed")
}

// queryRows sends args and starts reading the result set. failMsg labels
// errors raised while reading the response headers, so each public entry
// point reports its own name.
func (st *Stmt) queryRows(ctx context.Context, args []driver.Value, failMsg string) (driver.Rows, error) {
	if err := st.send(ctx, args); err != nil {
		return &emptyRows, st.s.checkErr(err, "tds: send failed while querying", false)
	}
	rows, err := newRow(ctx, st.s)
	return rows, st.s.checkErr(err, failMsg, true)
}
// NumInput returns the number of parameters the prepared statement expects,
// as announced by the server when the statement was prepared.
// The receiver is a pointer, like every other Stmt method, so that the whole
// struct is not copied on each call.
func (st *Stmt) NumInput() int {
	return len(st.row.columns)
}
// Close drops the prepared statement from the database by sending a dynamic
// dealloc request and reading the server's reply.
//
// The statement ID goes back to idPool only after deallocation succeeds.
// A failed deallocation therefore never leads to the same procedure name
// being reused.
func (st *Stmt) Close() error {
	if st.d == nil {
		// newStmt bailed out before preparing anything (invalid session):
		// there is nothing to deallocate and no ID to recycle.
		return nil
	}
	st.d.operation = dynamicDealloc
	st.d.status = 0
	// send message
	err := st.s.b.send(st.ctx, normalPacket, st.d)
	if err = st.s.checkErr(err, "tds: Close failed", false); err != nil {
		return err
	}
	// get response
	// TODO: parse dynamic token to get status
	// NOTE(review): initState receives a nil context here, whereas newStmt
	// passes its ctx. Confirm that initState tolerates nil before changing this.
	for f := st.s.initState(nil, nil); f != nil; f = f(st.s.state) {
	}
	if err := st.s.checkErr(st.s.state.err, "tds: close failed", true); err != nil {
		return err
	}
	// Pool a copy of the ID. Pooling &st.ID would keep the whole Stmt (its
	// messages, buffers and context) reachable from the pool.
	id := st.ID
	idPool.Put(&id)
	return nil
}
// ColumnConverter returns the converter for the idx-th parameter. Such a
// converter checks min, max, nullability, precision and scale, then converts
// to a valid driver.Value.
// Indexes without a dedicated converter fall back to
// driver.DefaultParameterConverter. This includes every index when the
// statement has no parameters.
func (st *Stmt) ColumnConverter(idx int) driver.ValueConverter {
	// Strict upper bound: idx == len(st.converters) is out of range.
	if idx >= 0 && idx < len(st.converters) {
		return st.converters[idx]
	}
	return driver.DefaultParameterConverter
}