Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add sybase 12.5 LastInsertedId() support #77

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions executesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,17 @@ func (conn *Conn) ExecuteSql(query string, params ...driver.Value) ([]*Result, e
return conn.Exec(sql)
}

// executeSqlSybase125 prepares sybase 12.5 compatible sql statement
// sql statement appends statement for last_inserted_id and rows_affected results
// Note: Due to @@IDENTITY limitations, LastInsertedID will not result in an error if used after an update or delete
// @@IDENTITY returns the last inserted id for the scope of the connection, regardless of the scope that produced it
func (conn *Conn) executeSqlSybase125(query string, params ...driver.Value) ([]*Result, error) {
statement, numParams := query2Statement(query)
_, numParams := query2Statement(query)
if numParams != len(params) {
return nil, fmt.Errorf("Incorrect number of params, expecting %d got %d", numParams, len(params))
}

statement += statusRowSybase125
query += statusRowSybase125
sql := strings.Replace(query, "?", "$bindkey", -1)
re, _ := regexp.Compile(`(?P<bindkey>\$bindkey)`)
matches := re.FindAllSubmatchIndex([]byte(sql), -1)
Expand All @@ -58,9 +62,8 @@ func (conn *Conn) executeSqlSybase125(query string, params ...driver.Value) ([]*
_, escapedValue, _ := go2SqlDataType(params[i])
sql = fmt.Sprintf("%s", strings.Replace(sql, "$bindkey", escapedValue, 1))
}

if numParams == 0 {
sql = fmt.Sprintf("%s", statement)
sql = fmt.Sprintf("%s", query)
}
return conn.Exec(sql)
}
Expand Down
3 changes: 3 additions & 0 deletions mssql_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ func (r *MssqlResult) statusRowValue(columnName string) int64 {
if val, ok := lastResult.Rows[0][idx].(float64); ok {
return int64(val)
}
if val, ok := lastResult.Rows[0][idx].(int32); ok {
return int64(val)
}
}
return -1
}
12 changes: 5 additions & 7 deletions mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package freetds
import (
"database/sql"
"fmt"
"os"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"os"
)

func open(t *testing.T) (*sql.DB, error, bool) {
Expand Down Expand Up @@ -44,7 +44,6 @@ func TestMssqlConnOpenSybase125(t *testing.T) {
c.Close()
}


func TestGoSqlDbQueryRow(t *testing.T) {
db, err, _ := open(t)
defer db.Close()
Expand Down Expand Up @@ -94,10 +93,7 @@ func TestGoSqlPrepareQuery(t *testing.T) {
func TestLastInsertIdRowsAffected(t *testing.T) {
db, _, sybase125 := open(t)
defer db.Close()
if sybase125 {
t.Skip("LastInsertId and RowsEffective not returned in Sybase 12.5")
}
createTestTable(t, db, sybase125,"test_last_insert_id", "")
createTestTable(t, db, sybase125, "test_last_insert_id", "")
r, err := db.Exec("insert into [test_last_insert_id] values(?)", "pero")
assert.Nil(t, err)
assert.NotNil(t, r)
Expand All @@ -122,7 +118,9 @@ func TestLastInsertIdRowsAffected(t *testing.T) {
assert.Nil(t, err)
assert.NotNil(t, r)
id, err = r.LastInsertId()
assert.NotNil(t, err)
if !sybase125 {
assert.NotNil(t, err)
}
ra, err = r.RowsAffected()
assert.Nil(t, err)
assert.EqualValues(t, ra, 2)
Expand Down