Skip to content

Commit

Permalink
Merge pull request #385 from MichaelS11/ctx
Browse files Browse the repository at this point in the history
Moved ctx into Stmt & Added Blob & Cursor Function Examples
  • Loading branch information
MichaelS11 authored Apr 6, 2020
2 parents 16089af + e3572fc commit 6a5d434
Show file tree
Hide file tree
Showing 9 changed files with 529 additions and 39 deletions.
23 changes: 19 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@ import (

// Ping database connection
func (conn *Conn) Ping(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}

done := make(chan struct{})
go conn.ociBreakDone(ctx, done)
result := C.OCIPing(conn.svc, conn.errHandle, C.OCI_DEFAULT)
close(done)

if result == C.OCI_SUCCESS || result == C.OCI_SUCCESS_WITH_INFO {
return nil
}
errorCode, err := conn.ociGetError()
if errorCode == 1010 {
// Older versions of Oracle do not support ping,
// but a response of "ORA-01010: invalid OCI operation" confirms connectivity.
// See https://github.com/rana/ora/issues/224
return nil
}

Expand Down Expand Up @@ -94,10 +98,17 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt

queryP := cString(query)
defer C.free(unsafe.Pointer(queryP))

// statement handle
var stmtTemp *C.OCIStmt
stmt := &stmtTemp

if ctx.Err() != nil {
return nil, ctx.Err()
}

done := make(chan struct{})
go conn.ociBreakDone(ctx, done)
defer func() { close(done) }()

if rv := C.OCIStmtPrepare2(
conn.svc, // service context handle
stmt, // pointer to the statement handle returned
Expand All @@ -112,7 +123,7 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
return nil, conn.getError(rv)
}

return &Stmt{conn: conn, stmt: *stmt}, nil
return &Stmt{conn: conn, stmt: *stmt, ctx: ctx}, nil
}

// Begin starts a transaction
Expand All @@ -122,6 +133,10 @@ func (conn *Conn) Begin() (driver.Tx, error) {

// BeginTx starts a transaction
func (conn *Conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (driver.Tx, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

if conn.transactionMode != C.OCI_TRANS_READWRITE {
// transaction handle
trans, _, err := conn.ociHandleAlloc(C.OCI_HTYPE_TRANS, 0)
Expand Down
4 changes: 4 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ func (connector *Connector) Driver() driver.Driver {

// Connect returns a new database connection
func (connector *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

conn := &Conn{
logger: connector.Logger,
}
Expand Down
178 changes: 177 additions & 1 deletion example_sql_go112_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func Example_sqlCursor() {
// Example shows how to do a cursor select

// For testing, check if database tests are disabled
if oci8.TestDisableDatabase || oci8.TestDisableDestructive {
if oci8.TestDisableDatabase {
fmt.Println(3)
return
}
Expand Down Expand Up @@ -161,3 +161,179 @@ func Example_sqlCursor() {

// output: 3
}

func Example_sqlCursorFunction() {
// Example shows how to do a cursor select from function

// For testing, check if database tests are disabled
if oci8.TestDisableDatabase || oci8.TestDisableDestructive {
fmt.Println(3)
return
}

oci8.Driver.Logger = log.New(os.Stderr, "oci8 ", log.Ldate|log.Ltime|log.LUTC|log.Lshortfile)

var openString string
// [username/[password]@]host[:port][/service_name][?param1=value1&...&paramN=valueN]
if len(oci8.TestUsername) > 0 {
if len(oci8.TestPassword) > 0 {
openString = oci8.TestUsername + "/" + oci8.TestPassword + "@"
} else {
openString = oci8.TestUsername + "@"
}
}
openString += oci8.TestHostValid

// A normal simple Open to localhost would look like:
// db, err := sql.Open("oci8", "127.0.0.1")
// For testing, need to use additional variables
db, err := sql.Open("oci8", openString)
if err != nil {
fmt.Printf("Open error is not nil: %v", err)
return
}
if db == nil {
fmt.Println("db is nil")
return
}

// defer close database
defer func() {
err = db.Close()
if err != nil {
fmt.Println("Close error is not nil:", err)
}
}()

ctx, cancel := context.WithTimeout(context.Background(), 55*time.Second)
err = db.PingContext(ctx)
cancel()
if err != nil {
fmt.Println("PingContext error is not nil:", err)
return
}

// create function
functionName := "E_F_CURSOR_" + oci8.TestTimeString
query := `create or replace function ` + functionName + ` return SYS_REFCURSOR
is
l_cursor SYS_REFCURSOR;
begin
open l_cursor for select 2 from dual union select 3 from dual;
return l_cursor;
end ` + functionName + `;`
ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second)
_, err = db.ExecContext(ctx, query)
cancel()
if err != nil {
fmt.Println("ExecContext error is not nil:", err)
return
}

var rows *sql.Rows
ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second)
defer cancel()
rows, err = db.QueryContext(ctx, "select 1, "+functionName+"() from dual")
if err != nil {
fmt.Println("QueryContext error is not nil:", err)
return
}

// defer close rows
defer func() {
err = rows.Close()
if err != nil {
fmt.Println("Close error is not nil:", err)
}
}()

if !rows.Next() {
fmt.Println("no Next rows")
return
}

var aInt int64
var subRows *sql.Rows
err = rows.Scan(&aInt, &subRows)
if err != nil {
fmt.Println("Scan error is not nil:", err)
return
}

if aInt != 1 {
fmt.Println("aInt != 1")
return
}
if subRows == nil {
fmt.Println("subRows is nil")
return
}

if !subRows.Next() {
fmt.Println("no Next subRows")
return
}

err = subRows.Scan(&aInt)
if err != nil {
fmt.Println("Scan error is not nil:", err)
return
}

if aInt != 2 {
fmt.Println("aInt != 2")
return
}

if !subRows.Next() {
fmt.Println("no Next subRows")
return
}

err = subRows.Scan(&aInt)
if err != nil {
fmt.Println("Scan error is not nil:", err)
return
}

if aInt != 3 {
fmt.Println("aInt != 3")
return
}

if subRows.Next() {
fmt.Println("has Next rows")
return
}

err = subRows.Err()
if err != nil {
fmt.Println("Err error is not nil:", err)
return
}

if rows.Next() {
fmt.Println("has Next rows")
return
}

err = rows.Err()
if err != nil {
fmt.Println("Err error is not nil:", err)
return
}

// drop function
query = "drop function " + functionName
ctx, cancel = context.WithTimeout(context.Background(), 55*time.Second)
_, err = db.ExecContext(ctx, query)
cancel()
if err != nil {
fmt.Println("ExecContext error is not nil:", err)
return
}

fmt.Println(aInt)

// output: 3
}
Loading

0 comments on commit 6a5d434

Please sign in to comment.