-
Notifications
You must be signed in to change notification settings - Fork 39
/
api_stmt.go
111 lines (103 loc) · 2.36 KB
/
api_stmt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package db
import (
"database/sql"
lua "github.com/yuin/gopher-lua"
)
// luaStmt is the value stored in a stmt_ud userdata: a prepared statement
// together with the database handle it was prepared on.
type luaStmt struct {
	// Embedded prepared statement; its methods (Query, Exec, Close) are
	// called directly on luaStmt by the stmt_ud bindings.
	*sql.Stmt
	// d is the *sql.DB the statement was prepared on.
	d *sql.DB
}
// Stmt prepares a SQL statement on a database userdata.
//
// Lua signature: db_ud:stmt(query) returns (stmt_ud, err).
//
// Argument 1 is the database userdata (validated by checkDB) and argument 2
// is the query text. On success one value is returned: a userdata wrapping a
// *luaStmt with the `stmt_ud` metatable attached. If preparing the statement
// fails, nil and the error message are returned instead.
func Stmt(L *lua.LState) int {
	handle := checkDB(L, 1)
	sqlText := L.CheckString(2)
	conn := handle.getDB()

	prepared, err := conn.Prepare(sqlText)
	if err != nil {
		L.Push(lua.LNil)
		L.Push(lua.LString(err.Error()))
		return 2
	}

	// Keep the originating connection next to the prepared statement.
	wrapped := L.NewUserData()
	wrapped.Value = &luaStmt{Stmt: prepared, d: conn}
	L.SetMetatable(wrapped, L.GetTypeMetatable(`stmt_ud`))
	L.Push(wrapped)
	return 1
}
// getSTMTArgs collects the Lua call arguments that follow the statement
// userdata (stack positions 2..top) into a slice suitable for passing to
// (*sql.Stmt).Query or (*sql.Stmt).Exec.
//
// A Lua nil is translated to an untyped Go nil; every other value is passed
// through unchanged as its lua.LValue. The returned slice is never nil, even
// when no extra arguments were supplied.
func getSTMTArgs(L *lua.LState) []interface{} {
	top := L.GetTop()

	// Position 1 holds the statement itself, so at most top-1 bind values.
	n := top - 1
	if n < 0 {
		n = 0
	}
	args := make([]interface{}, 0, n)

	for i := 2; i <= top; i++ {
		v := L.CheckAny(i)
		if v.Type() == lua.LTNil {
			// Append a true nil interface rather than the lua.LNil value.
			args = append(args, nil)
			continue
		}
		args = append(args, v)
	}
	return args
}
// StmtQuery runs a prepared statement that returns rows.
//
// Lua signature: stmt_ud:query(args) returns ({rows = {}, columns = {}}, err).
//
// Argument 1 must be a userdata holding a *luaStmt; otherwise a Lua argument
// error is raised. The remaining arguments are gathered by getSTMTArgs and
// passed to the statement. On success a single table is returned whose
// `rows` and `columns` fields are the values produced by parseRows. If the
// query or the row parsing fails, nil and the error message are returned.
func StmtQuery(L *lua.LState) int {
	stmt, isStmt := L.CheckUserData(1).Value.(*luaStmt)
	if !isStmt {
		L.ArgError(1, "must be stmt_ud")
	}

	// Both failure paths report the error the same way: (nil, message).
	fail := func(e error) int {
		L.Push(lua.LNil)
		L.Push(lua.LString(e.Error()))
		return 2
	}

	rs, err := stmt.Query(getSTMTArgs(L)...)
	if err != nil {
		return fail(err)
	}
	defer rs.Close()

	rows, columns, err := parseRows(rs, L)
	if err != nil {
		return fail(err)
	}

	out := L.NewTable()
	out.RawSetString(`rows`, rows)
	out.RawSetString(`columns`, columns)
	L.Push(out)
	return 1
}
// StmtExec runs a prepared statement that does not return rows.
//
// Lua signature: stmt_ud:exec(args) returns
// ({rows_affected=number, last_insert_id=number}, err).
//
// Argument 1 must be a userdata holding a *luaStmt; otherwise a Lua argument
// error is raised. The remaining arguments are gathered by getSTMTArgs and
// passed to the statement. If execution fails, nil and the error message are
// returned. On success a single table is returned; `last_insert_id` and
// `rows_affected` are each set only when the corresponding sql.Result call
// succeeds, so either field may be absent.
func StmtExec(L *lua.LState) int {
	stmt, isStmt := L.CheckUserData(1).Value.(*luaStmt)
	if !isStmt {
		L.ArgError(1, "must be stmt_ud")
	}

	res, err := stmt.Exec(getSTMTArgs(L)...)
	if err != nil {
		L.Push(lua.LNil)
		L.Push(lua.LString(err.Error()))
		return 2
	}

	summary := L.NewTable()

	// Errors from the two result accessors are skipped on purpose: the
	// field is simply left out of the returned table.
	lastID, idErr := res.LastInsertId()
	if idErr == nil {
		summary.RawSetString(`last_insert_id`, lua.LNumber(lastID))
	}
	affected, affErr := res.RowsAffected()
	if affErr == nil {
		summary.RawSetString(`rows_affected`, lua.LNumber(affected))
	}

	L.Push(summary)
	return 1
}
// StmtClose closes a prepared statement.
//
// Lua signature: stmt_ud:close() returns err.
//
// Argument 1 must be a userdata holding a *luaStmt; otherwise a Lua argument
// error is raised. Nothing is returned when the close succeeds; on failure
// the error message is returned as a single string.
func StmtClose(L *lua.LState) int {
	stmt, isStmt := L.CheckUserData(1).Value.(*luaStmt)
	if !isStmt {
		L.ArgError(1, "must be stmt_ud")
	}

	err := stmt.Close()
	if err == nil {
		return 0
	}
	L.Push(lua.LString(err.Error()))
	return 1
}